Spaces:
Build error
Build error
| import gradio as gr | |
| import requests | |
| import json | |
| from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification, AutoModelForQuestionAnswering | |
| from datasets import load_dataset | |
| import datasets | |
| import plotly.io as pio | |
| import plotly.graph_objects as go | |
| import plotly.express as px | |
| from plotly.subplots import make_subplots | |
| import pandas as pd | |
| from sklearn.metrics import confusion_matrix | |
| import importlib | |
| import torch | |
| from dash import Dash, html, dcc | |
| import numpy as np | |
| from sklearn.metrics import accuracy_score | |
| from sklearn.metrics import f1_score | |
def load_model(model_type: str, model_name_or_path: str, dataset_name: str, config_name: str):
    """Load a tokenizer and a task-specific model head.

    For the classification tasks the dataset is loaded so the classifier head
    can be sized to the dataset's label vocabulary.

    Args:
        model_type: One of ``"text_classification"``, ``"token_classification"``
            or ``"question_answering"``.
        model_name_or_path: Hub id or local path of the pretrained checkpoint.
        dataset_name: Hub id of the dataset used to size the classifier head.
        config_name: Dataset configuration (e.g. ``"sst2"`` for GLUE). An empty
            string or ``None`` means "no configuration".

    Returns:
        ``(tokenizer, model)``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported values.
    """
    if model_type not in ("text_classification", "token_classification", "question_answering"):
        raise ValueError(f"Invalid model type: {model_type}")

    # A blank Gradio textbox yields "", which is not a valid config name.
    config_name = config_name or None
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

    if model_type == "question_answering":
        return tokenizer, AutoModelForQuestionAnswering.from_pretrained(model_name_or_path)

    dataset = load_dataset(dataset_name, config_name)
    # Read label metadata from "train" when present, otherwise from whichever
    # split the dataset does provide.
    split = dataset["train"] if "train" in dataset else next(iter(dataset.values()))
    if model_type == "text_classification":
        num_labels = len(split.features["label"].names)
        model_cls = AutoModelForSequenceClassification
    else:  # token_classification
        num_labels = len(split.features["ner_tags"].feature.names)
        model_cls = AutoModelForTokenClassification

    # The Auto* classes pick the concrete architecture (RoBERTa included) from
    # the checkpoint's config, so no per-architecture special-casing is needed.
    model = model_cls.from_pretrained(model_name_or_path, num_labels=num_labels)
    return tokenizer, model
def test_model(tokenizer, model, test_data: list, label_map: dict):
    """Run ``model`` on every example and pair each prediction with its gold label.

    Args:
        tokenizer: Callable tokenizer; called with ``return_tensors="pt"``.
        model: Classification model whose output exposes ``.logits``.
        test_data: Iterable of ``(text, context, true_label)`` triples; the
            context element is ignored.
        label_map: Maps a predicted class index to its label name.

    Returns:
        List of ``(text, true_label, pred_label)`` tuples in input order.

    Raises:
        KeyError: If the predicted index is not a key of ``label_map``.
    """
    results = []
    # Pure inference: disabling autograd avoids building a computation graph
    # for every example, which only costs memory and time here.
    with torch.no_grad():
        for text, _context, true_label in test_data:
            inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
            outputs = model(**inputs)
            pred_index = int(outputs.logits.argmax(dim=-1))
            results.append((text, true_label, label_map[pred_index]))
    return results


# Despite its name this is an inference helper, not a pytest test; stop test
# runners from collecting it when this module is imported into a test file.
test_model.__test__ = False
def generate_label_map(dataset):
    """Build an ``index -> label`` mapping from a dataset's ``label`` feature.

    A ``datasets.ClassLabel`` feature contributes its declared names in index
    order. Any other label feature falls back to the distinct values present
    in the data, sorted so the mapping is identical across runs.

    Args:
        dataset: A ``datasets.Dataset`` (anything exposing ``.features`` and
            column access via ``dataset["label"]``).

    Returns:
        Dict from integer index to label; empty when there is no usable
        ``label`` feature.
    """
    if "label" not in dataset.features or dataset.features["label"] is None:
        return {}
    label_feature = dataset.features["label"]
    if isinstance(label_feature, datasets.ClassLabel):
        return dict(enumerate(label_feature.names))
    # A bare set has no guaranteed iteration order (string hashes are
    # randomised per process), so sort to keep index assignment stable.
    return dict(enumerate(sorted(set(dataset["label"]))))
def calculate_fairness_score(results, label_map):
    """Return overall accuracy and a between-group disparity ("fairness") score.

    Background on fairness metrics: https://arxiv.org/pdf/1908.09635.pdf

    Each label name in ``label_map`` is treated as a group. A group's rate is
    the share of its examples (true label == group) that were predicted
    correctly, i.e. the diagonal entry of that group's row-normalised
    confusion matrix. The score is the mean absolute difference of these
    rates over every unordered pair of groups that has at least one example:
    0 means all groups are served equally well, 1 is the largest possible gap.

    Args:
        results: Sequence of ``(text, true_label, pred_label)`` tuples.
        label_map: Mapping whose values are the label names (groups).

    Returns:
        ``(accuracy, score)`` as floats. Both are ``0.0`` for empty
        ``results``; the score is ``0.0`` when fewer than two groups have
        examples.
    """
    from itertools import combinations

    if not results:
        return 0.0, 0.0
    true_labels = [r[1] for r in results]
    pred_labels = [r[2] for r in results]
    # Same value sklearn's accuracy_score yields for two label lists.
    accuracy = sum(t == p for t, p in zip(true_labels, pred_labels)) / len(results)

    # Compare per-group correct-prediction rates rather than raw per-group
    # confusion matrices: those matrices have non-zero entries only in their
    # own (disjoint) rows, so their difference depends on class counts alone
    # and never on the model's predictions.
    group_rates = []
    # dict.fromkeys removes duplicate label names while keeping label_map order.
    for group in dict.fromkeys(label_map.values()):
        group_preds = [p for t, p in zip(true_labels, pred_labels) if t == group]
        if group_preds:  # an empty group has no defined rate; skip it
            group_rates.append(sum(p == group for p in group_preds) / len(group_preds))

    gaps = [abs(a - b) for a, b in combinations(group_rates, 2)]
    score = sum(gaps) / len(gaps) if gaps else 0.0
    return accuracy, score
def calculate_per_class_metrics(true_labels, pred_labels, label_map, metric='accuracy'):
    """Compute one metric value per class defined by ``label_map``.

    The result follows ``label_map.values()`` order, so callers can pair it
    positionally with ``label_map.values()`` (as the plotting and reporting
    code does).

    Args:
        true_labels: Gold label names.
        pred_labels: Predicted label names, aligned with ``true_labels``.
        label_map: Mapping whose values are the class names.
        metric: ``'accuracy'`` (share of a class's examples predicted
            correctly) or ``'f1'``.

    Returns:
        List of floats, one per class. A class with no gold examples gets
        ``0.0`` accuracy, matching how sklearn's default ``zero_division``
        setting scores undefined F1 values.

    Raises:
        ValueError: If ``metric`` is not recognised.
    """
    labels = list(label_map.values())
    if metric == 'accuracy':
        metrics = []
        for label in labels:
            preds = [p for t, p in zip(true_labels, pred_labels) if t == label]
            # Every gold label in this subset equals ``label``, so accuracy is
            # simply the fraction of predictions that match it.
            metrics.append(sum(p == label for p in preds) / len(preds) if preds else 0.0)
        return metrics
    if metric == 'f1':
        return f1_score(true_labels, pred_labels, labels=labels, average=None).tolist()
    raise ValueError(f"Invalid metric: {metric}")
def generate_fairness_statement(accuracy, fairness_score):
    """Turn accuracy and fairness score into a one-paragraph assessment.

    The fairness score is bucketed as low (<= 0.15), moderate (<= 0.3) or
    high (anything else, including NaN); a higher score means more bias.

    Args:
        accuracy: Overall accuracy as a fraction (rendered as a percentage).
        fairness_score: Disparity score (rendered with two decimals).

    Returns:
        A string starting with ``"Assessment: "``.
    """
    if fairness_score <= 0.15:
        verdict = f"The low fairness score ({fairness_score:.2f}) and accuracy ({accuracy * 100:.2f}%) indicate that the model is relatively fair and does not exhibit significant bias across different groups."
    elif fairness_score <= 0.3:
        verdict = f"The moderate fairness score ({fairness_score:.2f}) and accuracy ({accuracy * 100:.2f}%) suggest that the model may have some bias across different groups, and further investigation is needed to ensure it does not disproportionately affect certain groups."
    else:
        verdict = f"The high fairness score ({fairness_score:.2f}) and accuracy ({accuracy * 100:.2f}%) indicate that the model exhibits significant bias across different groups, and it's recommended to address this issue to ensure fair predictions for all groups."
    return "Assessment: " + verdict
def _per_class_bar_figure(values, label_map, title, y_title, background_color, text_color):
    """Draw one coloured bar per class, in ``label_map`` order, with the shared chart styling."""
    colors = px.colors.qualitative.Plotly
    fig = go.Figure()
    for i, label in enumerate(label_map.values()):
        fig.add_trace(go.Bar(
            x=[label],
            y=[values[i]],
            name=label,
            marker_color=colors[i % len(colors)],
        ))
    fig.update_xaxes(showgrid=True, gridwidth=1,
                     gridcolor='LightGray', linecolor='black', linewidth=1)
    fig.update_yaxes(showgrid=True, gridwidth=1,
                     gridcolor='LightGray', linecolor='black', linewidth=1)
    fig.update_layout(plot_bgcolor=background_color,
                      paper_bgcolor=background_color,
                      font=dict(color=text_color),
                      title=title,
                      xaxis_title='Class', yaxis_title=y_title)
    return fig


def generate_visualization(visualization_type, results, label_map, chart_mode):
    """Build the plotly figure selected in the UI.

    Args:
        visualization_type: ``"confusion_matrix"``, ``"per_class_accuracy"``,
            ``"per_class_f1"`` or ``"interactive_dashboard"``.
        results: Sequence of ``(text, true_label, pred_label)`` tuples.
        label_map: Mapping whose values are the class names, in plotting order.
        chart_mode: ``"Light"`` for a white background; anything else is dark.

    Returns:
        A plotly figure.

    Raises:
        ValueError: For an unknown ``visualization_type``.
    """
    if visualization_type == "confusion_matrix":
        return generate_report_card(results, label_map, chart_mode)["fig"]
    if visualization_type == "interactive_dashboard":
        return generate_interactive_dashboard(results, label_map, chart_mode)

    # visualization_type -> (metric passed to calculate_per_class_metrics, chart title, y-axis title)
    bar_charts = {
        "per_class_accuracy": ('accuracy', 'Per-Class Accuracy', 'Accuracy'),
        "per_class_f1": ('f1', 'Per-Class F1-Score', 'F1-Score'),
    }
    if visualization_type not in bar_charts:
        raise ValueError(f"Invalid visualization type: {visualization_type}")

    metric, title, y_title = bar_charts[visualization_type]
    true_labels = [r[1] for r in results]
    pred_labels = [r[2] for r in results]
    values = calculate_per_class_metrics(true_labels, pred_labels, label_map, metric=metric)
    background_color = "white" if chart_mode == "Light" else "black"
    text_color = "black" if chart_mode == "Light" else "white"
    return _per_class_bar_figure(values, label_map, title, y_title, background_color, text_color)
def generate_interactive_dashboard(results, label_map, chart_mode):
    """Combine the confusion matrix and the per-class bar charts into one figure.

    Layout: the confusion-matrix heatmap spans the top row; per-class accuracy
    (left) and per-class F1 (right) sit underneath.

    Args:
        results: Sequence of ``(text, true_label, pred_label)`` tuples.
        label_map: Mapping whose values are the class names, in plotting order.
        chart_mode: ``"Light"`` for a white background; anything else is dark.

    Returns:
        A plotly figure built with ``make_subplots``.
    """
    palette = ['#EF553B', '#00CC96', '#636EFA', '#AB63FA', '#FFA15A',
               '#19D3F3', '#FF6692', '#B6E880', '#FF97FF', '#FECB52']
    background_color = "white" if chart_mode == "Light" else "black"
    text_color = "black" if chart_mode == "Light" else "white"

    # The report card already holds the heatmap and both per-class metric
    # lists, so reuse them instead of recomputing.
    report_card = generate_report_card(results, label_map, chart_mode)
    class_names = list(label_map.values())
    # Cycle the palette so datasets with more than 10 classes still get a colour per bar.
    bar_colors = [palette[i % len(palette)] for i in range(len(class_names))]
    accuracy_bars = go.Bar(x=class_names, y=report_card["per_class_accuracy"],
                           marker=dict(color=bar_colors))
    f1_bars = go.Bar(x=class_names, y=report_card["per_class_f1"],
                     marker=dict(color=bar_colors))

    fig = make_subplots(rows=2, cols=2, shared_xaxes=True,
                        specs=[[{"colspan": 2}, None], [{}, {}]],
                        subplot_titles=("Confusion Matrix", "Per-Class Accuracy", "Per-Class F1-Score"))
    fig.add_trace(report_card["fig"].data[0], row=1, col=1)
    fig.add_trace(accuracy_bars, row=2, col=1)
    fig.add_trace(f1_bars, row=2, col=2)
    fig.update_xaxes(showgrid=True, gridwidth=1,
                     gridcolor='LightGray', linecolor='black', linewidth=1)
    fig.update_yaxes(showgrid=True, gridwidth=1,
                     gridcolor='LightGray', linecolor='black', linewidth=1)
    fig.update_layout(height=700, width=650,
                      plot_bgcolor=background_color,
                      paper_bgcolor=background_color,
                      font=dict(color=text_color),
                      title="Fairness Report", showlegend=False)
    return fig
def generate_report_card(results, label_map, chart_mode):
    """Build the confusion-matrix figure plus the summary metrics shown in the UI.

    Args:
        results: Sequence of ``(text, true_label, pred_label)`` tuples.
        label_map: Mapping whose values are the class names; their order fixes
            the heatmap's row/column order.
        chart_mode: ``"Light"`` for a white background; anything else is dark.

    Returns:
        Dict with keys:
            ``fig``: plotly heatmap of the row-normalised confusion matrix
                (hover shows the raw counts);
            ``accuracy``: overall accuracy as a fraction;
            ``fairness_score``: the ``(accuracy, score)`` tuple returned by
                ``calculate_fairness_score``;
            ``per_class_accuracy`` / ``per_class_f1``: per-class metric lists
                from ``calculate_per_class_metrics``.
    """
    true_labels = [r[1] for r in results]
    pred_labels = [r[2] for r in results]
    background_color = "white" if chart_mode == "Light" else "black"
    text_color = "black" if chart_mode == "Light" else "white"
    class_names = list(label_map.values())

    # Pin rows/columns to label_map order so they line up with the axis labels
    # below. Without ``labels=`` sklearn uses the sorted set of labels it sees,
    # which mislabels the axes when label_map is not alphabetical or a class
    # is missing from the sample.
    cm = confusion_matrix(true_labels, pred_labels, labels=class_names or None)
    # Row-normalise; a class with no gold examples keeps an all-zero row
    # instead of the NaNs a plain division would produce.
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_normalized = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float),
                              where=row_totals != 0)

    fig = go.Figure(go.Heatmap(z=cm_normalized,
                               x=class_names,
                               y=class_names,
                               text=cm,  # raw counts, shown on hover
                               hovertemplate='%{text}',
                               colorscale=[[0, '#EF553B'], [1, '#00CC96']],
                               showscale=False,
                               zmin=0, zmax=1))
    fig.update_xaxes(showgrid=True, gridwidth=1,
                     gridcolor='LightGray', linecolor='black', linewidth=1)
    fig.update_yaxes(showgrid=True, gridwidth=1,
                     gridcolor='LightGray', linecolor='black', linewidth=1)
    fig.update_layout(
        plot_bgcolor=background_color,
        paper_bgcolor=background_color,
        font=dict(color=text_color),
        height=500, width=600,
        title='Confusion Matrix',
        xaxis=dict(title='Predicted Labels'),
        yaxis=dict(title='True Labels'),
    )

    return {
        "fig": fig,
        # Fraction of correct predictions (normalize=False would give a raw count).
        "accuracy": accuracy_score(true_labels, pred_labels),
        # Kept as the full (accuracy, score) tuple; callers index or unpack it.
        "fairness_score": calculate_fairness_score(results, label_map),
        "per_class_accuracy": calculate_per_class_metrics(
            true_labels, pred_labels, label_map, metric='accuracy'),
        "per_class_f1": calculate_per_class_metrics(
            true_labels, pred_labels, label_map, metric='f1'),
    }
def generate_insights(custom_text, model_name, dataset_name, accuracy, fairness_score, report_card, generator):
    """Ask a text-generation callable for narrative insights about an evaluation.

    Args:
        custom_text: Text placed at the very start of the prompt.
        model_name: Model identifier mentioned in the prompt.
        dataset_name: Dataset identifier mentioned in the prompt.
        accuracy: Overall accuracy as a fraction (rendered as a percentage).
        fairness_score: Fairness score (rendered with two decimals).
        report_card: Dict that may contain ``per_class_accuracy`` and
            ``per_class_f1`` lists; when either is missing or empty the prompt
            says per-class metrics are unavailable.
        generator: Callable used like a transformers text-generation pipeline;
            must return a list whose first item has a ``generated_text`` key.

    Returns:
        The ``generated_text`` of the first generation.
    """
    metrics = {
        'accuracy': report_card.get('per_class_accuracy', []),
        'f1': report_card.get('per_class_f1', []),
    }
    preamble = (
        f"{custom_text} The model {model_name} has been evaluated on the {dataset_name} dataset. "
        f"It has an overall accuracy of {accuracy * 100:.2f}%. "
        f"The fairness score is {fairness_score:.2f}. "
    )
    have_metrics = bool(metrics['accuracy']) and bool(metrics['f1'])
    if have_metrics:
        request = (f"The per-class metrics are: {metrics}. "
                   "Please provide some interesting insights about the fairness, bias, and per-class performance.")
    else:
        request = ("Per-class metrics could not be calculated. "
                   "Please provide some interesting insights about the fairness and bias of the model.")

    generations = generator(preamble + request, max_length=600, do_sample=True, temperature=0.7)
    return generations[0]['generated_text']
# Text-generation pipelines keyed by (model name, device) so the generator is
# loaded once per process rather than on every request.
_TEXT_GENERATORS = {}


def _get_text_generator(model_name='gpt2'):
    """Return a cached text-generation pipeline, loading it on first use."""
    # Use the first GPU when one is available, otherwise the CPU (-1).
    device = 0 if torch.cuda.is_available() else -1
    key = (model_name, device)
    if key not in _TEXT_GENERATORS:
        _TEXT_GENERATORS[key] = pipeline('text-generation', model=model_name, device=device)
    return _TEXT_GENERATORS[key]


def _build_test_data(dataset, dataset_name):
    """Turn a labelled split into ``(text, None, label_name)`` triples.

    The text column is ``"text"`` for tweet_eval and ``"sentence"`` otherwise,
    falling back to ``"text"`` when the dataset has no ``"sentence"`` column.
    Examples with a negative label (the placeholder used for unlabelled test
    splits) are skipped instead of silently indexing ``names[-1]``.
    """
    if dataset_name == "tweet_eval" or "sentence" not in dataset.features:
        text_column = "text"
    else:
        text_column = "sentence"
    label_names = dataset.features["label"].names
    return [(item[text_column], None, label_names[item["label"]])
            for item in dataset
            if item["label"] >= 0]


def app(model_type: str, model_name_or_path: str, dataset_name: str, config_name: str, dataset_split: str, num_samples: int, visualization_type: str, chart_mode: str):
    """Gradio callback: evaluate a model on a dataset slice and report on fairness.

    Args:
        model_type: Task type forwarded to ``load_model``.
        model_name_or_path: Checkpoint to evaluate.
        dataset_name: Hub dataset id.
        config_name: Dataset configuration; blank means none.
        dataset_split: Split name, e.g. ``"validation"``.
        num_samples: Number of leading examples of the split to evaluate.
        visualization_type: Figure type forwarded to ``generate_visualization``.
        chart_mode: ``"Light"`` or ``"Dark"``.

    Returns:
        ``(report_text, figure)`` for the Textbox and Plot outputs.

    Raises:
        ValueError: If the selected slice contains no labelled examples.
    """
    # A blank Gradio textbox yields "", which load_dataset does not accept as "no config".
    config_name = config_name or None
    tokenizer, model = load_model(model_type, model_name_or_path, dataset_name, config_name)

    # Gradio's Number component delivers a float; the split slice needs an int.
    num_samples = int(num_samples)
    dataset = load_dataset(dataset_name, config_name, split=f"{dataset_split}[:{num_samples}]")
    test_data = _build_test_data(dataset, dataset_name)
    if not test_data:
        raise ValueError(
            f"No labelled examples in the first {num_samples} rows of split '{dataset_split}'.")

    label_map = generate_label_map(dataset)
    results = test_model(tokenizer, model, test_data, label_map)

    report_card = generate_report_card(results, label_map, chart_mode)
    visualization = generate_visualization(visualization_type, results, label_map, chart_mode)

    # The report card already carries calculate_fairness_score's (accuracy, score) tuple.
    accuracy, fairness_score = report_card['fairness_score']
    fairness_statement = generate_fairness_statement(accuracy, fairness_score)
    insights = generate_insights(fairness_statement, model_name_or_path, dataset_name,
                                 accuracy, fairness_score, report_card, _get_text_generator())

    per_class_metrics_str = "\n".join(
        f"{label}: Acc {acc:.2f}, F1 {f1:.2f}"
        for label, acc, f1 in zip(label_map.values(),
                                  report_card['per_class_accuracy'],
                                  report_card['per_class_f1']))
    report_text = (f"{insights}\n\n"
                   f"Accuracy: {accuracy * 100:.2f}%, Fairness Score: {fairness_score:.2f}\n\n"
                   f"Per-Class Metrics:\n{per_class_metrics_str}")
    return report_text, visualization
# Gradio UI. Input components are listed in the same order as app()'s
# positional parameters. Uses the top-level gr.* components with `value=`;
# the legacy gr.inputs / gr.outputs namespaces and `default=` were removed in Gradio 4.
interface = gr.Interface(
    fn=app,
    inputs=[
        gr.Radio(["text_classification", "token_classification", "question_answering"],
                 label="Model Type", value="text_classification"),
        gr.Textbox(lines=1, label="Model Name or Path",
                   placeholder="ex: distilbert-base-uncased-finetuned-sst-2-english",
                   value="distilbert-base-uncased-finetuned-sst-2-english"),
        gr.Textbox(lines=1, label="Dataset Name",
                   placeholder="ex: glue", value="glue"),
        # Default to sst2 so it matches both the placeholder and the SST-2 default model.
        gr.Textbox(lines=1, label="Config Name",
                   placeholder="ex: sst2", value="sst2"),
        gr.Dropdown(choices=["train", "validation", "test"],
                    label="Dataset Split", value="validation"),
        gr.Number(value=100, label="Number of Samples", precision=0),
        gr.Dropdown(choices=["interactive_dashboard", "confusion_matrix", "per_class_accuracy", "per_class_f1"],
                    label="Visualization Type", value="interactive_dashboard"),
        gr.Radio(["Light", "Dark"], label="Chart Mode", value="Light"),
    ],
    outputs=[
        gr.Textbox(label="Fairness and Bias Metrics"),
        gr.Plot(label="Graph"),
    ],
    title="Fairness and Bias Testing",
    description="Enter a model and dataset to test for fairness and bias.",
)
# Module-level binary sentiment mapping (index -> name). The functions in this
# module receive their label map as a parameter rather than reading this name.
label_map = dict(enumerate(["negative", "positive"]))

if __name__ == "__main__":
    interface.launch()