Spaces:
Runtime error
Runtime error
Update visualization.py
Browse files- visualization.py +19 -14
visualization.py
CHANGED
|
@@ -231,26 +231,24 @@ def plot_correlation_heatmap(mse_embeddings, mse_posture, mse_voice):
|
|
| 231 |
plt.tight_layout()
|
| 232 |
return plt.gcf()
|
| 233 |
|
| 234 |
-
def plot_stacked_mse_heatmaps(mse_face, mse_posture, mse_voice, df, title="
|
| 235 |
-
plt.figure(figsize=(20,
|
| 236 |
-
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(20,
|
| 237 |
|
| 238 |
# Face heatmap
|
| 239 |
-
sns.heatmap(mse_face.reshape(1, -1), cmap='YlOrRd', cbar=False, ax=ax1)
|
| 240 |
-
ax1.
|
| 241 |
-
ax1.
|
| 242 |
-
ax1.set_xticks([])
|
| 243 |
|
| 244 |
# Posture heatmap
|
| 245 |
-
sns.heatmap(mse_posture.reshape(1, -1), cmap='YlOrRd', cbar=False, ax=ax2)
|
| 246 |
-
ax2.
|
| 247 |
-
ax2.
|
| 248 |
-
ax2.set_xticks([])
|
| 249 |
|
| 250 |
# Voice heatmap
|
| 251 |
-
sns.heatmap(mse_voice.reshape(1, -1), cmap='YlOrRd', cbar=False, ax=ax3)
|
| 252 |
-
ax3.
|
| 253 |
-
ax3.
|
| 254 |
|
| 255 |
# Set x-axis ticks to timecodes for the bottom subplot
|
| 256 |
num_ticks = min(60, len(mse_voice))
|
|
@@ -259,6 +257,13 @@ def plot_stacked_mse_heatmaps(mse_face, mse_posture, mse_voice, df, title="Stack
|
|
| 259 |
ax3.set_xticks(tick_locations)
|
| 260 |
ax3.set_xticklabels(tick_labels, rotation=90, ha='center', va='top')
|
| 261 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
plt.suptitle(title)
|
| 263 |
plt.tight_layout()
|
| 264 |
plt.close()
|
|
|
|
| 231 |
plt.tight_layout()
|
| 232 |
return plt.gcf()
|
| 233 |
|
| 234 |
+
def plot_stacked_mse_heatmaps(mse_face, mse_posture, mse_voice, df, title="Combined MSE Heatmaps"):
|
| 235 |
+
plt.figure(figsize=(20, 6), dpi=300)
|
| 236 |
+
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(20, 6), sharex=True, gridspec_kw={'height_ratios': [1, 1, 1.2], 'hspace': 0})
|
| 237 |
|
| 238 |
# Face heatmap
|
| 239 |
+
sns.heatmap(mse_face.reshape(1, -1), cmap='YlOrRd', cbar=False, ax=ax1, xticklabels=False, yticklabels=False)
|
| 240 |
+
ax1.set_ylabel('Face', rotation=0, ha='right', va='center')
|
| 241 |
+
ax1.yaxis.set_label_coords(-0.01, 0.5)
|
|
|
|
| 242 |
|
| 243 |
# Posture heatmap
|
| 244 |
+
sns.heatmap(mse_posture.reshape(1, -1), cmap='YlOrRd', cbar=False, ax=ax2, xticklabels=False, yticklabels=False)
|
| 245 |
+
ax2.set_ylabel('Posture', rotation=0, ha='right', va='center')
|
| 246 |
+
ax2.yaxis.set_label_coords(-0.01, 0.5)
|
|
|
|
| 247 |
|
| 248 |
# Voice heatmap
|
| 249 |
+
sns.heatmap(mse_voice.reshape(1, -1), cmap='YlOrRd', cbar=False, ax=ax3, yticklabels=False)
|
| 250 |
+
ax3.set_ylabel('Voice', rotation=0, ha='right', va='center')
|
| 251 |
+
ax3.yaxis.set_label_coords(-0.01, 0.5)
|
| 252 |
|
| 253 |
# Set x-axis ticks to timecodes for the bottom subplot
|
| 254 |
num_ticks = min(60, len(mse_voice))
|
|
|
|
| 257 |
ax3.set_xticks(tick_locations)
|
| 258 |
ax3.set_xticklabels(tick_labels, rotation=90, ha='center', va='top')
|
| 259 |
|
| 260 |
+
# Remove spines
|
| 261 |
+
for ax in [ax1, ax2, ax3]:
|
| 262 |
+
ax.spines['top'].set_visible(False)
|
| 263 |
+
ax.spines['right'].set_visible(False)
|
| 264 |
+
ax.spines['bottom'].set_visible(False)
|
| 265 |
+
ax.spines['left'].set_visible(False)
|
| 266 |
+
|
| 267 |
plt.suptitle(title)
|
| 268 |
plt.tight_layout()
|
| 269 |
plt.close()
|