Spaces:
Running
Running
Heatmap: Taken create_heatmap() from Jupyter notebook created by Martin Fajcik
Browse files- analyze_winscore.py +109 -0
analyze_winscore.py
CHANGED
|
@@ -207,4 +207,113 @@ def create_scatter_plot_with_curve_with_variances_named(category, variance_acros
|
|
| 207 |
|
| 208 |
return p
|
| 209 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
# EOF
|
|
|
|
| 207 |
|
| 208 |
return p
|
| 209 |
|
| 210 |
+
def create_heatmap(data_matrix, original_scores, selected_rows=None, hide_scores_tasks=[], width=800, height=1400):
|
| 211 |
+
plot_width = width
|
| 212 |
+
plot_height = height
|
| 213 |
+
n_rows, n_cols = data_matrix.shape
|
| 214 |
+
|
| 215 |
+
# Clean column names (remove 'benczechmark_' prefix)
|
| 216 |
+
# data_matrix.columns = data_matrix.columns.str.replace('benczechmark_', '', regex=False)
|
| 217 |
+
# original_scores.columns = original_scores.columns.str.replace('benczechmark_', '', regex=False)
|
| 218 |
+
|
| 219 |
+
if selected_rows is not None:
|
| 220 |
+
# Select only the specified rows (models)
|
| 221 |
+
data_matrix = data_matrix[selected_rows]
|
| 222 |
+
original_scores = original_scores[selected_rows]
|
| 223 |
+
|
| 224 |
+
# Set up the figure with tasks as x-axis and models as y-axis
|
| 225 |
+
p = figure(
|
| 226 |
+
width=plot_width, height=plot_height,
|
| 227 |
+
x_range=list(data_matrix.index), y_range=list(data_matrix.columns),
|
| 228 |
+
toolbar_location="below", tools="pan,wheel_zoom,box_zoom,reset,save",
|
| 229 |
+
x_axis_label="Model", y_axis_label="Category"
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
# Create the color mapper for the heatmap
|
| 233 |
+
color_mapper = LinearColorMapper(palette='Viridis256', low=0, high=1) # Light for low values, dark for high
|
| 234 |
+
|
| 235 |
+
# Flatten the matrix for Bokeh plotting
|
| 236 |
+
heatmap_data = {
|
| 237 |
+
'x': [],
|
| 238 |
+
'y': [],
|
| 239 |
+
'colors': [],
|
| 240 |
+
'model_names': [], # Updated: Reflects model names now
|
| 241 |
+
'scores': [],
|
| 242 |
+
}
|
| 243 |
+
label_data = {
|
| 244 |
+
'x': [],
|
| 245 |
+
'y': [],
|
| 246 |
+
'value': [],
|
| 247 |
+
'text_color': [], # New field for label text colors
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
# Iterate through the data_matrix to populate heatmap and label data
|
| 251 |
+
for row_idx, (model_name, task_scores) in enumerate(data_matrix.iterrows()):
|
| 252 |
+
for col_idx, score in enumerate(task_scores):
|
| 253 |
+
heatmap_data['x'].append(model_name) # Model goes to x-axis
|
| 254 |
+
heatmap_data['y'].append(data_matrix.columns[col_idx]) # Task goes to y-axis
|
| 255 |
+
heatmap_data['colors'].append(score)
|
| 256 |
+
heatmap_data['model_names'].append(model_name) # Model names added to hover info
|
| 257 |
+
|
| 258 |
+
# Get the original score
|
| 259 |
+
original_score = original_scores.loc[model_name, data_matrix.columns[col_idx]]
|
| 260 |
+
plot_score = data_matrix.loc[model_name, data_matrix.columns[col_idx]]
|
| 261 |
+
heatmap_data['scores'].append(original_score)
|
| 262 |
+
task_name = data_matrix.columns[col_idx]
|
| 263 |
+
|
| 264 |
+
if task_name not in hide_scores_tasks:
|
| 265 |
+
label_data['x'].append(model_name)
|
| 266 |
+
label_data['y'].append(task_name)
|
| 267 |
+
label_data['value'].append(round(original_score)) # Round the score
|
| 268 |
+
|
| 269 |
+
# Determine text color based on score
|
| 270 |
+
if plot_score <= 0.6: # Threshold for light/dark text
|
| 271 |
+
label_data['text_color'].append('white') # Light color for lower scores
|
| 272 |
+
else:
|
| 273 |
+
label_data['text_color'].append('black') # Dark color for higher scores
|
| 274 |
+
|
| 275 |
+
heatmap_source = ColumnDataSource(heatmap_data)
|
| 276 |
+
label_source = ColumnDataSource(label_data)
|
| 277 |
+
|
| 278 |
+
# Create the heatmap
|
| 279 |
+
p.rect(x='x', y='y', width=1, height=1, source=heatmap_source,
|
| 280 |
+
line_color=None, fill_color={'field': 'colors', 'transform': color_mapper})
|
| 281 |
+
|
| 282 |
+
# Add color bar
|
| 283 |
+
# Add color bar with custom ticks
|
| 284 |
+
color_bar = ColorBar(
|
| 285 |
+
color_mapper=color_mapper,
|
| 286 |
+
width=8, location=(0, 0),
|
| 287 |
+
ticker=FixedTicker(ticks=[0, 0.2, 0.4, 0.6, 0.8, 1]), # Fixed ticks at 0, 20, 40, 60, 80, 100
|
| 288 |
+
major_label_overrides={0: '0', 0.2: '20', 0.4: '40', 0.6: '60', 0.8: '80', 1: '100'} # Custom labels for ticks
|
| 289 |
+
)
|
| 290 |
+
#p.add_layout(color_bar, 'right')
|
| 291 |
+
|
| 292 |
+
# Add HoverTool for interactivity
|
| 293 |
+
hover = HoverTool()
|
| 294 |
+
hover.tooltips = [("Model", "@x"), ("Task", "@y"), ("DS", "@scores")] # Updated tooltip
|
| 295 |
+
p.add_tools(hover)
|
| 296 |
+
|
| 297 |
+
# Add labels with dynamic text color
|
| 298 |
+
labels = LabelSet(x='x', y='y', text='value', source=label_source,
|
| 299 |
+
text_color='text_color', text_align='center', text_baseline='middle')
|
| 300 |
+
p.add_layout(labels)
|
| 301 |
+
|
| 302 |
+
# Customize the plot appearance
|
| 303 |
+
p.xgrid.grid_line_color = None
|
| 304 |
+
p.ygrid.grid_line_color = None
|
| 305 |
+
p.xaxis.major_label_orientation = "vertical"
|
| 306 |
+
p.yaxis.major_label_text_font_size = "13pt"
|
| 307 |
+
p.xaxis.major_label_text_font_size = "13pt"
|
| 308 |
+
|
| 309 |
+
# Set the axis label font size
|
| 310 |
+
p.xaxis.axis_label_text_font_size = "18pt" # Set font size for x-axis label
|
| 311 |
+
p.yaxis.axis_label_text_font_size = "18pt" # Set font size for y-axis label
|
| 312 |
+
p.xaxis.axis_label_text_font_style = "normal" # Set x-axis label to normal
|
| 313 |
+
p.yaxis.axis_label_text_font_style = "normal" # Set y-axis label to normal
|
| 314 |
+
|
| 315 |
+
#p.yaxis.visible = False # Hide the y-axis labels
|
| 316 |
+
|
| 317 |
+
return p
|
| 318 |
+
|
| 319 |
# EOF
|