Spaces:
Runtime error
Runtime error
| from PIL import Image | |
| import os, csv | |
| import pandas as pd | |
| import numpy as np | |
| import gradio as gr | |
# Prompt vocabulary that feeds both dropdown menus.
prompts = pd.read_csv('promptsadjectives.csv')

# First ten masculine- and ten feminine-coded adjectives from the CSV.
masc, fem = (prompts[column][:10].tolist() for column in ('Masc-adj', 'Fem-adj'))

# Merged and sorted, with a leading '' entry meaning "no adjective".
adjectives = ['', *sorted(masc + fem)]

# First 150 occupation nouns.
occupations = prompts['Occupation-Noun'][:150].tolist()
# Placeholder image shown when no average face exists for a prompt.
_BLANK_FACE = 'facer_faces/blank.png'

# One entry per model, in gallery order:
# (directory holding the averaged-face PNGs, gallery caption,
#  whether to try the lowercased file name before the prompt's own spelling).
# The DALL-E 2 directory is looked up lowercased first; the Stable Diffusion
# directories use the prompt's own spelling first.
_MODEL_FACE_DIRS = (
    ('facer_faces/SDv14/', "Stable Diffusion v 1.4", False),
    ('facer_faces/SDv2/', "Stable Diffusion v 2", False),
    ('facer_faces/dalle2/', "Dall-E 2", True),
)


def _average_face_path(model_dir, prompt, lowercase_first):
    """Return the PNG path for *prompt* under *model_dir*, or the blank image.

    Both the prompt's own spelling and its lowercased form are tried, in the
    order selected by *lowercase_first*. This tolerates the case mismatch
    between model directories. Prompts containing a path separator are
    rejected so request values cannot point outside *model_dir*.
    """
    if '/' in prompt or os.sep in prompt or (os.altsep and os.altsep in prompt):
        return _BLANK_FACE
    if lowercase_first:
        candidates = (prompt.lower(), prompt)
    else:
        candidates = (prompt, prompt.lower())
    # dict.fromkeys de-duplicates while preserving order (the prompt may
    # already be all lowercase).
    for name in dict.fromkeys(candidates):
        path = model_dir + name + '.png'
        if os.path.isfile(path):
            return path
    return _BLANK_FACE


def _load_image(path):
    """Open *path* and return an in-memory copy, closing the file handle."""
    with Image.open(path) as img:
        return img.copy()


def get_averages(adj, profession):
    """Return the average faces generated by each model for a prompt.

    The prompt is "<adj> <profession>" (or just the profession when *adj* is
    empty), with spaces replaced by underscores to form the file name.

    Args:
        adj: Adjective chosen in the dropdown; '' means no adjective.
        profession: Profession chosen in the dropdown.

    Returns:
        A tuple of (PIL image, caption) pairs, one per model, in the order
        Stable Diffusion v1.4, Stable Diffusion v2, DALL-E 2. A model without
        an average face for the prompt contributes the blank placeholder.
    """
    # A dropdown with no selection may arrive as None (assumption: this
    # depends on the Gradio version). Treat it like the empty string.
    adj = adj or ''
    profession = profession or ''
    if adj != '':
        prompt = (adj + ' ' + profession).replace(' ', '_')
    else:
        prompt = profession.replace(' ', '_')
    return tuple(
        (_load_image(_average_face_path(model_dir, prompt, lowercase_first)), caption)
        for model_dir, caption, lowercase_first in _MODEL_FACE_DIRS
    )
# Gradio UI: two dropdowns (adjective, profession), a button, and a gallery
# that shows each model's average face for the chosen prompt.
with gr.Blocks() as demo:
    gr.Markdown("# Text-to-Image Diffusion Model Average Faces")
    gr.Markdown("### We ran 150 professions and 20 adjectives through 3 text-to-image diffusion models to examine what they generate.")
    gr.Markdown("#### Choose one of the professions and adjectives from the dropdown menus and see the average face generated by each model.")
    gr.HTML("""<span style="color:red">⚠️ <b>DISCLAIMER: the images displayed by this tool are based on images which were generated by text-to-image models which may depict offensive stereotypes or contain explicit content.</b></span>""")
    with gr.Row():
        with gr.Column():
            adj = gr.Dropdown(sorted(adjectives, key=str.casefold), value='', label="Choose an adjective", interactive=True)
            prof = gr.Dropdown(sorted(occupations, key=str.casefold), value='', label="Choose a profession", interactive=True)
            btn = gr.Button("Get average faces!")
        with gr.Column():
            # Layout options are passed to the constructor. Component.style()
            # was removed in Gradio 4.0, so the old
            # `.style(grid=[0, 3], height="auto")` call raises an
            # AttributeError on current Gradio. `columns=3` matches the
            # three-model gallery the old grid setting targeted.
            gallery = gr.Gallery(
                label="Average images",
                show_label=False,
                elem_id="gallery",
                columns=3,
                height="auto",
            )
            gr.Markdown("The three models are: Stable Diffusion v.1.4, Stable Diffusion v.2, and Dall-E 2.")
            gr.Markdown("If you see a black square above, we weren't able to compute an average face for this profession!")
    # Event listeners must be registered inside the Blocks context.
    btn.click(fn=get_averages, inputs=[adj, prof], outputs=gallery)

demo.launch()