Spaces:
Running
Running
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| import numpy as np | |
| from scipy.special import logit | |
# Load the raw results, discard chrF rows, and collapse to one mean score per
# (task, metric, language) combination.
df = (
    pd.read_json("../results.json")
    .loc[lambda frame: frame["metric"] != "chrf"]
    .groupby(["task", "metric", "bcp_47"])
    .agg({"score": "mean"})
    .reset_index()
)
# Apply logit transformation to classification scores to reduce skewness
def transform_classification_scores(row):
    """Return the score of ``row``, logit-transformed for classification rows.

    ``row`` is anything indexable by the keys ``'task'`` and ``'score'``
    (for example a DataFrame row passed by ``DataFrame.apply(axis=1)``).

    Rows whose task is not ``'classification'`` get their score back unchanged.
    Classification scores are first clipped to ``[0.001, 0.999]`` so that the
    logit, ``log(p / (1 - p))``, stays finite at 0 and 1.
    """
    if row['task'] != 'classification':
        return row['score']
    clipped = np.clip(row['score'], 0.001, 0.999)
    return logit(clipped)
# Replace each score with its transformed value (only classification rows
# change; see transform_classification_scores).
df = df.assign(score=df.apply(transform_classification_scores, axis=1))

# Languages as rows, tasks as columns; scores that share a (language, task)
# cell are averaged.
pivot_df = pd.pivot_table(
    df,
    values='score',
    index='bcp_47',
    columns='task',
    aggfunc='mean',
)

# Column order for the analysis. Any task that is not listed here
# (e.g. 'truthfulqa') is dropped from the pivot table below.
ordered_tasks = [
    'translation_from',
    'translation_to',
    'classification',
    'mmlu',
    'arc',
    'mgsm',
]
# Keep only the listed tasks that actually occur in the data, in list order.
available_tasks = [task for task in ordered_tasks if task in pivot_df.columns]
pivot_df = pivot_df[available_tasks]

# Pairwise correlation between task columns (DataFrame.corr defaults)
correlation_matrix = pivot_df.corr()
# Draw the correlation matrix as an annotated heatmap
plt.figure(figsize=(8, 6))

# Boolean mask that is True on the upper triangle and the diagonal; masked
# cells are not drawn, so only the lower triangle is shown.
mask = np.triu(np.ones(correlation_matrix.shape, dtype=bool))

heatmap_options = {
    'annot': True,   # write the value into each cell ...
    'fmt': '.3f',    # ... with three decimals
    'cmap': 'Blues',
    'center': 0,
    'square': True,
    'mask': mask,
    'cbar_kws': {"shrink": .8},
}
sns.heatmap(correlation_matrix, **heatmap_options)

plt.xlabel('Tasks', fontsize=12)
plt.ylabel('Tasks', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()

# Write the figure to disk before showing it
plt.savefig('task_correlation_matrix.png', dpi=300, bbox_inches='tight')
plt.show()

# Echo the numeric correlations for reference
print(
    "Correlation Matrix:",
    "Note: Classification scores have been logit-transformed to reduce skewness",
    sep="\n",
)
print(correlation_matrix.round(3))
# Also create a scatter plot matrix for pairwise relationships with highlighted languages

# Single source of truth for the highlighted languages: language code -> marker
# color. The list of highlighted codes is derived from this mapping so the two
# cannot drift apart (a code in the list without a color would raise KeyError).
HIGHLIGHT_COLORS = {'en': 'red', 'zh': 'blue', 'hi': 'green', 'es': 'orange', 'ar': 'purple'}
highlighted_languages = list(HIGHLIGHT_COLORS)


# Create color mapping
def get_color_and_label(lang_code):
    """Return the ``(color, legend_label)`` pair used to draw a language.

    Args:
        lang_code: Language code to look up (e.g. ``'en'``).

    Returns:
        ``(color, lang_code)`` for a highlighted language, otherwise
        ``('lightgray', 'Other')``.
    """
    color = HIGHLIGHT_COLORS.get(lang_code)
    if color is None:
        return 'lightgray', 'Other'
    return color, lang_code
# Create custom scatter plot matrix: one row and one column of panels per task.
# Diagonal panels show the score histogram of a task; off-diagonal panels plot
# one task against another with one marker per language.
tasks = pivot_df.columns.tolist()
n_tasks = len(tasks)
# squeeze=False makes plt.subplots always return a 2-D array of Axes, so the
# axes[i, j] indexing below also works when there is only one task column.
fig, axes = plt.subplots(n_tasks, n_tasks, figsize=(15, 12), squeeze=False)
fig.suptitle('Pairwise Task Performance', fontsize=16, fontweight='bold')

# Create legend elements: one proxy marker per highlighted language, plus one
# for all remaining languages.
legend_elements = [
    plt.Line2D([0], [0], marker='o', color='w',
               markerfacecolor=get_color_and_label(lang)[0], markersize=8, label=lang)
    for lang in highlighted_languages
]
legend_elements.append(
    plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='lightgray', markersize=8, label='Other')
)

# A language's marker style does not depend on the task pair, so compute it
# once per language instead of once per point in every panel.
marker_styles = {}
for lang_code in pivot_df.index:
    is_highlighted = lang_code in highlighted_languages
    marker_styles[lang_code] = {
        'c': get_color_and_label(lang_code)[0],
        'alpha': 0.8 if is_highlighted else 0.3,
        's': 50 if is_highlighted else 20,
    }

for i, task_y in enumerate(tasks):
    for j, task_x in enumerate(tasks):
        ax = axes[i, j]
        if i == j:
            # Diagonal: histogram of the task's non-missing scores
            ax.hist(pivot_df[task_y].dropna(), bins=20, alpha=0.7, color='skyblue', edgecolor='black')
            ax.set_title(f'{task_y}', fontsize=10)
        else:
            # Off-diagonal: scatter plot of the languages that have a score
            # for both tasks. Points are drawn one at a time, in index order,
            # so each keeps its own color, alpha and size.
            both_present = pivot_df[[task_x, task_y]].dropna()
            for lang_code, x_val, y_val in both_present.itertuples(name=None):
                ax.scatter(x_val, y_val, **marker_styles[lang_code])
        # Set labels: task names only on the bottom row and the left column
        if i == n_tasks - 1:
            ax.set_xlabel(task_x, fontsize=10)
        if j == 0:
            ax.set_ylabel(task_y, fontsize=10)
        # Remove tick labels except for edges
        if i != n_tasks - 1:
            ax.set_xticklabels([])
        if j != 0:
            ax.set_yticklabels([])

# Add legend: a single row centred just below the figure; bbox_inches='tight'
# in savefig keeps it inside the saved image.
fig.legend(
    handles=legend_elements,
    loc='lower center',
    bbox_to_anchor=(0.5, -0.05),
    ncol=len(legend_elements),
    frameon=False,
    fontsize=10,
    handletextpad=0.5,
    columnspacing=1.0,
)
plt.tight_layout()
plt.savefig('task_scatter_matrix.png', dpi=300, bbox_inches='tight')
plt.show()