Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import json | |
| import numpy as np | |
| import pandas as pd | |
| from datasets import load_from_disk | |
| from itertools import chain | |
| import operator | |
# make_profession_plot passes plotly-specific options (``barmode``) to
# DataFrame.plot, so pandas plotting has to be routed through plotly.
pd.options.plotting.backend = "plotly"

TITLE = "Identity Biases in Diffusion Models: Professions"

# Markdown shown at the top of the app. The tab names quoted here must match
# the labels passed to gr.Tab further down ("Professions Overview" and
# "Profession Focus").
_INTRO = """
# Identity Biases in Diffusion Models: Professions
Explore profession-level social biases in the data from [DiffusionBiasExplorer](https://hf.co/spaces/tti-bias/diffusion-bias-explorer)!
This demo leverages the gender and ethnicity representation clusters described in the [companion app](https://hf.co/spaces/tti-bias/diffusion-face-clustering)
to analyze social trends in machine-generated visual representations of professions.
The **Professions Overview** tab lets you compare the distribution over
[identity clusters](https://hf.co/spaces/tti-bias/diffusion-face-clustering "Identity clusters identify visual features in the systems' output space correlated with variation of gender and ethnicity in input prompts.")
across professions for Stable Diffusion and Dalle-2 systems (or aggregated for `All Models`).
The **Profession Focus** tab provides more details for each of the individual professions, including direct system comparisons and examples of profession images for each cluster.
This work was done in the scope of the [Stable Bias Project](https://hf.co/spaces/tti-bias/stable-bias).
"""

# Extra introduction text that is assigned to a throwaway name; nothing in the
# visible part of this file displays it.
_ = """
For example, you can use this tool to investigate:
- How do each model's representation of professions correlate with the gender ratios reported by the [U.S. Bureau of Labor
Statistics](https://www.bls.gov/cps/cpsaat11.htm "The reported percentage of women in each profession in the US is indicated in the `Labor Women` column in the Professions Overview tab.")?
Are social trends reflected, are they exaggerated?
- Which professions have the starkest differences in how different models represent them?
"""


def _load_json(path):
    """Parse the JSON file at *path*, closing the file handle afterwards."""
    with open(path, encoding="utf-8") as json_file:
        return json.load(json_file)


# Image dataset stored next to the app, plus a DataFrame view of it that
# get_image uses to locate rows by image path and model.
professions_dset = load_from_disk("professions")
professions_df = professions_dset.to_pandas()

# Per-profession cluster statistics, keyed by the number of clusters.
clusters_dicts = {
    num_cl: _load_json(f"clusters/professions_to_clusters_{num_cl}.json")
    for num_cl in [12, 24, 48]
}
cluster_summaries_by_size = _load_json("clusters/cluster_summaries_by_size.json")

prompts = pd.read_csv("promptsadjectives.csv")
# Names keep their original casing (the UI defaults below include "CEO").
professions = ["all professions"] + sorted(prompts["Occupation-Noun"].tolist())

# Model key used in the cluster files -> name displayed in the UI.
models = {
    "All": "All Models",
    "SD_14": "Stable Diffusion 1.4",
    "SD_2": "Stable Diffusion 2",
    "DallE": "Dall-E 2",
}
# Reverse lookup (displayed name -> model key), derived so that the two
# mappings cannot drift apart.
df_models = {display_name: key for key, display_name in models.items()}
def describe_cluster(num_clusters, block="label"):
    """Build an HTML snippet ranking the entries of one clustering by share.

    The items of ``clusters_dicts[num_clusters]`` are ordered from the largest
    to the smallest value, and every value is expressed as a percentage of the
    sum of all values, rounded to a whole number.

    NOTE(review): ``to_string`` is not defined in the visible part of this
    file and nothing visible calls this function -- confirm it is still needed.

    Args:
        num_clusters: Key selecting the entry of ``clusters_dicts`` to describe.
        block: Name of the kind of item being ranked; it is passed through
            ``to_string`` and used in the lead sentence.

    Returns:
        An HTML string: a ``<span>`` naming the top entry and its percentage,
        followed by a ``<p>`` listing the remaining entries one per line.
    """
    cl_dict = clusters_dicts[num_clusters]
    # Ascending sort followed by a reversal (not ``reverse=True``): the two
    # differ in how entries with equal values end up ordered.
    ranked = sorted(cl_dict.items(), key=operator.itemgetter(1))[::-1]
    total = float(sum(cl_dict.values()))
    shares = [(label, round(count * 100 / total, 0)) for label, count in ranked]

    top_label, top_share = shares[0]
    lead = (
        "<span>The most represented %s is <b>%s</b>, making up about <b>%d%%</b> of the cluster.</span>"
        % (to_string(block), to_string(top_label), top_share)
    )
    followers = "".join(
        "<BR/><b>%s:</b> %d%%" % (to_string(label), share)
        for label, share in shares[1:]
    )
    return lead + "<p>This is followed by: " + followers + "</p>"
def make_profession_plot(num_clusters, prof_name):
    """Plot how one profession is spread over the clusters, model by model.

    Clusters are ranked by their proportion in the aggregated ``"All"`` entry
    for the profession, and clusters whose proportion there is not positive
    are left out. The same ranked clusters are then read for every model in
    ``models``.

    Args:
        num_clusters: Number of clusters; a key of ``clusters_dicts``.
        prof_name: Profession to look up under each model entry.

    Returns:
        A 3-tuple of:
        - the figure produced by ``DataFrame.plot`` (grouped bars, one column
          per model, one row per cluster);
        - a ``gr.update`` whose choices are the ranked cluster ids and whose
          value is the top-ranked id, or ``None`` when no cluster qualifies;
        - a ``gr.update`` whose value is a text summary with one line per
          ranked cluster.
    """
    by_model = clusters_dicts[num_clusters]
    ranked_ids = [
        cl_id
        for cl_id, share in sorted(
            by_model["All"][prof_name]["cluster_proportions"].items(),
            key=operator.itemgetter(1),
            reverse=True,
        )
        if share > 0
    ]

    # One column per model (keyed by its display name), one row per cluster.
    pre_pandas = {
        display_name: {
            f"Cluster {cl_id}": by_model[mod_name][prof_name]["cluster_proportions"][
                cl_id
            ]
            for cl_id in ranked_ids
        }
        for mod_name, display_name in models.items()
    }
    df = pd.DataFrame.from_dict(pre_pandas)
    prof_plot = df.plot(kind="bar", barmode="group")

    def _summary_line(cl_id):
        # Cluster ids are dict keys here but list positions in the summaries.
        summary = cluster_summaries_by_size[str(num_clusters)][int(cl_id)]
        summary = summary.replace(" gender terms", "").replace(
            "; ethnicity terms:", ","
        )
        return f"- {summary} \n"

    cl_summary_text = f"Profession '{prof_name}':\n" + "".join(
        _summary_line(cl_id) for cl_id in ranked_ids
    )

    # Guard the default selection: indexing an empty ranking used to raise
    # IndexError when no cluster had a positive proportion.
    default_cluster = ranked_ids[0] if ranked_ids else None
    return (
        prof_plot,
        gr.update(choices=ranked_ids, value=default_cluster),
        gr.update(value=cl_summary_text),
    )
def make_profession_table(num_clusters, prof_names, mod_name, max_cols=8):
    """Build the overview table of cluster proportions per profession.

    Cluster proportions are summed over the selected professions, the
    ``max_cols`` clusters with the largest totals are kept, and those with a
    positive total become the table's cluster columns.

    Args:
        num_clusters: Number of clusters; a key of ``clusters_dicts``.
        prof_names: Professions to show, one table row each. An empty or
            missing selection yields an empty result instead of an error.
        mod_name: Display name of the model; translated through ``df_models``.
        max_cols: Maximum number of cluster columns.

    Returns:
        A 3-tuple of:
        - the ids of the top ``max_cols`` clusters, by decreasing total;
        - the styled table rendered as an HTML string;
        - a ``gr.update`` whose value is a text summary with one line per
          cluster column shown in the table.
    """
    if not prof_names:
        return [], "", gr.update(value="")

    model_entries = clusters_dicts[num_clusters][df_models[mod_name]]
    professions_list_clusters = [
        (prof_name, model_entries[prof_name]["cluster_proportions"])
        for prof_name in prof_names
    ]
    totals = sorted(
        (
            (
                k,
                sum(
                    prof_clusters[str(k)]
                    for _, prof_clusters in professions_list_clusters
                ),
            )
            for k in range(num_clusters)
        ),
        key=operator.itemgetter(1),
        reverse=True,
    )[:max_cols]
    # Only clusters that actually occur get a column -- and a summary line, so
    # the text below the table describes exactly the columns on display.
    shown_ids = [k for k, total in totals if total > 0]

    prof_list_pre_pandas = [
        {
            "Profession": prof_name,
            "Entropy": model_entries[prof_name]["entropy"],
            "Labor Women": model_entries[prof_name]["labor_fm"][0],
            # Empty, unnamed column between the statistics and the clusters.
            "": "",
            **{f"Cluster {k}": prof_clusters[str(k)] for k in shown_ids},
        }
        for prof_name, prof_clusters in professions_list_clusters
    ]
    clusters_df = pd.DataFrame.from_dict(prof_list_pre_pandas)

    summaries = cluster_summaries_by_size[str(num_clusters)]
    cl_summary_text = "".join(
        "- "
        + summaries[cl_id]
        .replace(" gender terms", "")
        .replace("; ethnicity terms:", ",")
        + " \n"
        for cl_id in shown_ids
    )

    table_html = (
        clusters_df.style.background_gradient(
            axis=None, vmin=0, vmax=100, cmap="YlGnBu"
        )
        .format(precision=1)
        .to_html()
    )
    return (
        [cl_id for cl_id, _ in totals],
        table_html,
        gr.update(value=cl_summary_text),
    )
def get_image(model, fname, score):
    """Fetch one image from the dataset together with a caption for it.

    Args:
        model: Model key; matched against the ``model`` column and looked up
            in ``models`` for the caption.
        fname: Value matched against the ``image_path`` column. Its first
            ``/``-separated component is split on ``_`` and the pieces from
            the fifth onwards form the start of the caption (presumably the
            prompt words -- confirm against the dataset's naming scheme).
        score: Number shown in the caption with two decimals.

    Returns:
        An ``(image, caption)`` tuple, where ``image`` is the ``image`` field
        of the first matching dataset row.
    """
    matching_rows = professions_df[
        (professions_df["image_path"] == fname) & (professions_df["model"] == model)
    ]
    image = professions_dset.select(matching_rows.index)["image"][0]

    folder_name = fname.split("/")[0]
    caption_head = " ".join(folder_name.split("_")[4:])
    caption = f"{caption_head} | {score:.2f} | {models[model]}"
    return image, caption
def show_examplars(num_clusters, prof_name, cl_id, confidence_threshold=0.6):
    """Collect the example images of one profession for one cluster.

    Candidates are taken from the ``"All"`` entry of the profession: every
    "close" examplar, the first two "mid" ones and every "far" one, in that
    order. A candidate is shown when its first field is above
    ``confidence_threshold`` and it has not appeared earlier in the list.

    Args:
        num_clusters: Number of clusters; a key of ``clusters_dicts``.
        prof_name: Profession whose examplars are shown.
        cl_id: Cluster id; converted to ``str`` for the lookup.
        confidence_threshold: Minimum (exclusive) score for an image to be
            shown; per the original comment, the score is the similarity to
            the cluster centroid.

    Returns:
        A 2-tuple of the ``get_image`` results for the retained candidates and
        a ``gr.update`` carrying the label describing the selection.
    """
    examplars_dict = clusters_dicts[num_clusters]["All"][prof_name][
        "cluster_examplars"
    ][str(cl_id)]
    pool = examplars_dict["close"] + examplars_dict["mid"][:2] + examplars_dict["far"]
    candidates = [tuple(entry) for entry in pool]

    retained = []
    already_seen = []
    for candidate in candidates:
        # A duplicate is dropped even when its first occurrence was itself
        # rejected by the threshold, so every candidate is recorded as seen.
        if candidate[0] > confidence_threshold and candidate not in already_seen:
            retained.append(candidate)
        already_seen.append(candidate)

    images = [get_image(model, fname, score) for score, model, fname in retained]
    label_update = gr.update(
        label=f"Generations for profession ''{prof_name}'' assigned to cluster {cl_id} of {num_clusters}"
    )
    return images, label_update
# NOTE(review): the scraped source lost its indentation, so the nesting of the
# layout containers below was reconstructed from the statement order -- please
# confirm it against the deployed layout.

# Help text displayed above the cluster summaries in both tabs.
_CLUSTER_HELP_MD = """
##### What do the clusters mean?
Below is a summary of the identity cluster compositions.
For more details, see the [companion demo](https://huggingface.co/spaces/tti-bias/DiffusionFaceClustering):
"""

with gr.Blocks(title=TITLE) as demo:
    gr.Markdown(_INTRO)
    # Both declarations belong inside the style attribute; the font size used
    # to sit outside of it, where it is not part of the style.
    gr.HTML(
        """<span style="color:red; font-size:smaller">⚠️ DISCLAIMER: the images displayed by this tool were generated by text-to-image systems and may depict offensive stereotypes or contain explicit content.</span>"""
    )

    with gr.Tab("Professions Overview"):
        gr.Markdown(
            """
            Select one or more professions and models from the dropdowns on the left to see which clusters are most representative for this combination.
            Try choosing different numbers of clusters to see if the results change, and then go to the 'Profession Focus' tab to go more in-depth into these results.
            The `Labor Women` column provided for comparison corresponds to the gender ratio reported by the
            [U.S. Bureau of Labor Statistics](https://www.bls.gov/cps/cpsaat11.htm) for each profession.
            """
        )
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("Select the parameters here:")
                num_clusters = gr.Radio(
                    [12, 24, 48],
                    value=12,
                    label="How many clusters do you want to use to represent identities?",
                )
                model_choices = gr.Dropdown(
                    [
                        "All Models",
                        "Stable Diffusion 1.4",
                        "Stable Diffusion 2",
                        "Dall-E 2",
                    ],
                    value="All Models",
                    label="Which models do you want to compare?",
                    interactive=True,
                )
                profession_choices_overview = gr.Dropdown(
                    professions,
                    value=[
                        "all professions",
                        "CEO",
                        "director",
                        "social assistant",
                        "social worker",
                    ],
                    label="Which professions do you want to compare?",
                    multiselect=True,
                    interactive=True,
                )
            with gr.Column(scale=3):
                with gr.Row():
                    # `wrap` is a Dataframe option rather than an HTML one, so
                    # it is no longer passed here.
                    table = gr.HTML(label="Profession assignment per cluster")
                with gr.Row():
                    # Hidden output receiving the cluster ids returned by
                    # make_profession_table.
                    clusters = gr.Textbox(label="clusters", visible=False)
                gr.Markdown(_CLUSTER_HELP_MD)
                with gr.Row():
                    with gr.Accordion(label="Cluster summaries", open=True):
                        cluster_descriptions_table = gr.Text(
                            "TODO", label="Cluster summaries", show_label=False
                        )

    with gr.Tab("Profession Focus"):
        with gr.Row():
            with gr.Column():
                gr.Markdown(
                    "Select a profession to visualize and see which clusters and identity groups are most represented in the profession, as well as some examples of generated images below."
                )
                profession_choice_focus = gr.Dropdown(
                    choices=professions,
                    value="scientist",
                    label="Select profession:",
                )
                num_clusters_focus = gr.Radio(
                    [12, 24, 48],
                    value=12,
                    label="How many clusters do you want to use to represent identities?",
                )
            with gr.Column():
                # Static label: formatting the Dropdown component into the
                # text inserted the object itself, not the chosen profession.
                plot = gr.Plot(
                    label="Makeup of the cluster assignments for the selected profession"
                )
        with gr.Row():
            with gr.Column():
                gr.Markdown(_CLUSTER_HELP_MD)
                with gr.Accordion(label="Cluster summaries", open=True):
                    cluster_descriptions = gr.Text(
                        "TODO", label="Cluster summaries", show_label=False
                    )
            with gr.Column():
                gr.Markdown(
                    """
                    ##### What's in the clusters?
                    You can show examples of profession images assigned to each identity cluster by selecting one here:
                    """
                )
                with gr.Accordion(label="Cluster selection", open=True):
                    # Initial choices only; make_profession_plot replaces them
                    # with the clusters present for the selected profession.
                    cluster_id_focus = gr.Dropdown(
                        choices=list(range(num_clusters_focus.value)),
                        value=0,
                        label="Select cluster to visualize:",
                    )
        with gr.Row():
            # NOTE(review): `.style(...)` is kept as written in the original;
            # check it against the Gradio version pinned for this Space.
            examplars_plot = gr.Gallery(
                label="Profession images assigned to the selected cluster."
            ).style(grid=4, height="auto", container=True)

    # Inputs and outputs of the three callbacks, shared by the initial page
    # load and by the change listeners registered below.
    table_inputs = [num_clusters, profession_choices_overview, model_choices]
    table_outputs = [clusters, table, cluster_descriptions_table]
    plot_inputs = [num_clusters_focus, profession_choice_focus]
    plot_outputs = [plot, cluster_id_focus, cluster_descriptions]
    examplars_inputs = [num_clusters_focus, profession_choice_focus, cluster_id_focus]
    # The gallery appears twice because show_examplars returns both the images
    # and an update of the gallery's label.
    examplars_outputs = [examplars_plot, examplars_plot]

    # Fill every output once when the page loads.
    demo.load(make_profession_table, table_inputs, table_outputs, queue=False)
    demo.load(make_profession_plot, plot_inputs, plot_outputs, queue=False)
    demo.load(show_examplars, examplars_inputs, examplars_outputs, queue=False)

    # Refresh the outputs whenever one of the relevant controls changes.
    for var in [num_clusters, model_choices, profession_choices_overview]:
        var.change(make_profession_table, table_inputs, table_outputs, queue=False)
    for var in [num_clusters_focus, profession_choice_focus]:
        var.change(make_profession_plot, plot_inputs, plot_outputs, queue=False)
    for var in [num_clusters_focus, profession_choice_focus, cluster_id_focus]:
        var.change(show_examplars, examplars_inputs, examplars_outputs, queue=False)

if __name__ == "__main__":
    demo.queue().launch(debug=True)