Spaces:
Running
Running
송종윤/AI Productivity팀(SR)/삼성전자
add models, add speed and time results, change scatter plot design
a452b10
| import pandas as pd | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| from plotly.graph_objs._figure import Figure | |
| from typing import Optional, List, Dict, Any | |
| from src.display.formatting import get_display_model_name | |
# Maps the user-facing sort/metric label to the DataFrame column that holds it.
SORT_COLUMN_MAP = dict(
    [
        ("Average Accuracy", "Avg AC"),
        ("Tool Selection Quality", "Avg TSQ"),
        ("Session Cost", "Avg Total Cost"),
    ]
)
def get_theme_colors(theme: str = "light") -> Dict[str, Any]:
    """Look up the chart colour palette for ``theme``.

    Only the exact value ``"dark"`` selects the dark palette; anything else
    (including the default ``"light"``) selects the light palette. Both
    palettes use a dark blue-gray background and differ only in its shade.

    Returns:
        A new dict with the keys ``paper_bg``, ``plot_bg``,
        ``legend_font_color``, ``legend_bg`` and ``annotation_color``.
    """
    # dark: darker blue-gray; light: deep blue-gray
    background = "#181c3a" if theme == "dark" else "#23244a"
    # Text and legend-box colours are shared by both themes.
    shared = {
        "legend_font_color": "#F5F6F7",
        "legend_bg": 'rgba(35,36,74,0.92)',
        "annotation_color": '#F5F6F7',
    }
    return {"paper_bg": background, "plot_bg": background, **shared}
def create_empty_radar_chart(message: str) -> Figure:
    """Create a placeholder figure that shows ``message`` in its center.

    Used by the chart builders whenever a chart cannot be drawn (missing
    columns, no data, unsupported metric). The figure has no traces, a bold
    title and a small "TRUEBench" watermark in the bottom-right corner.

    Args:
        message: Text displayed (prefixed with a chart emoji) in the middle.

    Returns:
        A plotly Figure containing only annotations and layout.
    """
    fig = go.Figure()
    fig.add_annotation(
        text=f"📊 {message}",
        xref="paper", yref="paper",
        x=0.5, y=0.5,
        xanchor='center', yanchor='middle',
        font=dict(
            size=18,
            color="#94A3B8",
            family="Verdana, sans-serif"
        ),
        showarrow=False,
        bgcolor="rgba(245, 246, 247, 0.05)",
        bordercolor="rgba(245, 246, 247, 0.2)",
        borderwidth=1,
        borderpad=20
    )
    # The watermark is added with add_annotation rather than by passing
    # ``annotations=[...]`` to update_layout. update_layout merges a list
    # element-wise into the existing annotations, which would overwrite the
    # message annotation above with the watermark's text and position.
    fig.add_annotation(
        text="TRUEBench",
        xref="paper", yref="paper",
        x=0.98, y=0.02,
        xanchor='right', yanchor='bottom',
        font=dict(size=10, color='#64748B'),
        showarrow=False
    )
    fig.update_layout(
        paper_bgcolor="#01091A",
        plot_bgcolor="rgba(245, 246, 247, 0.02)",
        height=800,
        width=800,
        margin=dict(t=100, b=80, l=80, r=80),
        title=dict(
            text="<b>Domain Performance Chart</b>",
            x=0.5,
            y=0.97,
            font=dict(
                size=22,
                family="Verdana, sans-serif",
                color="#F5F6F7",
                weight=700
            ),
        ),
    )
    return fig
def create_len_overall_scatter(
    df: pd.DataFrame,
    selected_models: Optional[List[str]] = None,
    max_models: int = 50,
    y_col: str = "Overall",
    length_data: Optional[dict] = None,
    theme: str = "light",
    x_axis_data_source: str = "Median Length",
    mode: str = "open"
) -> Figure:
    """
    Create a scatter plot of a median-length metric (x) against ``y_col`` (y).

    Each model is one marker: circles for 'Thinking' models and squares for
    'Non-Thinking' models (one legend entry each). Markers are coloured on a
    blue-to-orange scale by log10 of 'Parameter Size (B)' (1B..1000B). Dashed
    lines mark the median of each axis. Rows whose 'Type' is 'Proprietary'
    are dropped.

    Args:
        df: Leaderboard DataFrame. Must contain 'Model Name', 'Med. Len.',
            'Med. Resp. Len.', 'Think', 'Parameter Size (B)' and ``y_col``.
        selected_models: Model names to plot. When None/empty, the top
            ``max_models`` rows by 'Overall' are used (by ``y_col`` if the
            frame has no 'Overall' column).
        max_models: Number of models used when ``selected_models`` is empty.
        y_col: Score column plotted on the y axis; also the category key used
            to look up lengths in ``length_data``.
        length_data: Optional nested mapping
            ``{model: {category: {'Med' | 'Med Resp': value}}}``. It is used
            for the x values when ``y_col`` is not 'Overall'; missing entries
            become 0.
        theme: "light" or "dark" (default: "light").
        x_axis_data_source: "Median Length" selects 'Med. Len.'; any other
            value selects 'Med. Resp. Len.'.
        mode: Not used by this function; kept for interface compatibility.

    Returns:
        A plotly Figure, or a placeholder figure from
        ``create_empty_radar_chart`` when required data is missing.
    """
    # Resolve which DataFrame column drives the x axis and how it is labelled.
    x_axis_col_name = "Med. Len." if x_axis_data_source == "Median Length" else "Med. Resp. Len."
    length_data_key = 'Med' if x_axis_col_name == "Med. Len." else 'Med Resp'
    x_axis_display_name = x_axis_col_name.replace("Med.", "Median").replace("Len.", "Length")

    # Defensive: return a placeholder chart instead of raising on missing
    # columns. 'Parameter Size (B)' previously raised KeyError further down.
    for col in ['Model Name', 'Med. Len.', 'Med. Resp. Len.', y_col, 'Think', 'Parameter Size (B)']:
        if col not in df.columns:
            return create_empty_radar_chart(f"Column '{col}' not found in data")

    # Pick the models to plot.
    if selected_models:
        df_filtered = df[df['Model Name'].isin(selected_models)].copy()
    else:
        # Default: top-N by 'Overall'. Fall back to y_col when there is no
        # 'Overall' column (sorting by a missing column raised KeyError).
        sort_col = 'Overall' if 'Overall' in df.columns else y_col
        df_filtered = df.sort_values(sort_col, ascending=False).head(max_models).copy()

    # x values: per-category lengths from length_data when plotting a category
    # other than 'Overall' (missing entries -> 0), else the DataFrame column.
    if y_col != "Overall" and length_data:
        x_values = df_filtered['Model Name'].apply(
            lambda name: length_data.get(name, {}).get(y_col, {}).get(length_data_key, 0)
        )
    else:
        x_values = df_filtered[x_axis_col_name]
    # Unparseable values become NaN, which median() skips.
    df_filtered[x_axis_col_name] = pd.to_numeric(x_values, errors='coerce')
    df_filtered[y_col] = pd.to_numeric(df_filtered[y_col], errors='coerce')

    # Only non-proprietary models are plotted.
    if 'Type' in df_filtered.columns:
        df_filtered = df_filtered[df_filtered['Type'] != 'Proprietary'].copy()
    # Checked after the 'Proprietary' filter so an all-proprietary selection
    # also yields the placeholder instead of an empty plot.
    if df_filtered.empty:
        return create_empty_radar_chart(f"No data available for {x_axis_col_name} vs {y_col} analysis")

    legend_name_map = {'On': 'Thinking', 'Off': 'Non-Thinking'}
    # Unknown 'Think' values are kept verbatim and match neither trace below.
    df_filtered['ThinkDisplay'] = df_filtered['Think'].map(legend_name_map).fillna(df_filtered['Think'])

    def lerp_color(c1: str, c2: str, t: float) -> str:
        """Linearly interpolate two '#RRGGBB' colours; ``t`` must be in [0, 1]."""
        rgb1 = np.array([int(c1[i:i + 2], 16) for i in (1, 3, 5)])
        rgb2 = np.array([int(c2[i:i + 2], 16) for i in (1, 3, 5)])
        rgb = (1 - t) * rgb1 + t * rgb2
        return '#' + ''.join(f'{int(channel):02X}' for channel in rgb)

    # Marker colour encodes log10(Parameter Size (B)) on the same 0..3 scale as
    # the colorbar below (1B blue -> 1000B orange); unknown sizes map to 1B.
    # Clamp to [0, 1]: models above 1000B previously extrapolated past the
    # orange end and produced invalid colour strings (channels > 255 or < 0).
    param_size = pd.to_numeric(df_filtered['Parameter Size (B)'], errors='coerce')
    norm = (np.log10(param_size.clip(lower=1)).fillna(0) / 3).clip(0, 1)
    df_filtered['Color'] = [lerp_color('#00BFFF', '#FF8800', t) for t in norm]

    fig = go.Figure()
    median_x = df_filtered[x_axis_col_name].median()
    median_y = df_filtered[y_col].median()
    # Skip a median line when the whole column is NaN.
    if pd.notna(median_x):
        fig.add_vline(
            x=median_x,
            line_dash="dash",
            line_color="#64748B",
            opacity=0.6,
            line_width=1.5,
            annotation_text=x_axis_display_name,
            annotation_position="top right",
            annotation_font=dict(size=10, color="#64748B")
        )
    if pd.notna(median_y):
        fig.add_hline(
            y=median_y,
            line_dash="dash",
            line_color="#64748B",
            opacity=0.6,
            line_width=1.5,
            annotation_text=f"Median {y_col}",
            annotation_position="bottom right",
            annotation_font=dict(size=10, color="#64748B")
        )

    for think, marker_type in [('Thinking', 'circle'), ('Non-Thinking', 'square')]:
        sub_df = df_filtered[df_filtered['ThinkDisplay'] == think]
        if sub_df.empty:
            continue
        # Squares look larger than circles at equal size, so shrink them a bit.
        marker_size = 25 if marker_type == 'square' else 30
        fig.add_trace(go.Scatter(
            x=sub_df[x_axis_col_name],
            y=sub_df[y_col],
            mode='markers+text',
            name=think,
            legendgroup=think,
            showlegend=True,
            marker_symbol=marker_type,
            marker=dict(
                size=marker_size,
                color=sub_df['Color'],
                opacity=0.85,
                line=dict(width=2, color='#01091A')
            ),
            text=sub_df['Model Name'].apply(get_display_model_name),
            textposition="top center",
            textfont=dict(size=10, color='#94A3B8'),
            hovertemplate="<b>%{text}</b><br>" +
                          x_axis_display_name + ": %{x:.2f}<br>" +
                          y_col + ": %{y:.2f}<br>" +
                          "Think: " + think + "<br>" +
                          "Parameter Size: %{customdata}B<br>" +
                          "<extra></extra>",
            customdata=sub_df['Parameter Size (B)'].values
        ))

    theme_colors = get_theme_colors(theme)
    # Invisible trace whose only purpose is to render the colorbar:
    # log10(Parameter Size (B)) 0..3 with ticks labelled 1, 10, 100, 1000.
    colorbar_trace = go.Scatter(
        x=[None], y=[None],
        mode='markers',
        marker=dict(
            size=0.1,
            color=[0, 1, 2, 3],
            colorscale=[[0, '#00BFFF'], [1, '#FF8800']],
            cmin=0, cmax=3,
            colorbar=dict(
                title={
                    'text': 'Parameter Size (B)',
                    'font': dict(
                        color=theme_colors["legend_font_color"],
                        family="Verdana, sans-serif",
                        size=14
                    )
                },
                tickvals=[0, 1, 2, 3],
                ticktext=['1', '10', '100', '1000'],
                tickfont=dict(
                    color=theme_colors["legend_font_color"],
                    family="Verdana, sans-serif",
                    size=12
                ),
                lenmode='pixels',
                len=500,
                thickness=36,
                x=1.02,
                y=0.5,
                yanchor='middle'
            ),
            showscale=True
        ),
        showlegend=False,
        hoverinfo='none'
    )
    fig.add_trace(colorbar_trace)

    fig.update_layout(
        title=dict(
            text=f"<b>{y_col} {x_axis_display_name} and Category Score</b>",
            x=0.5,
            y=0.97,
            font=dict(size=22, family="Verdana, sans-serif", color=theme_colors["legend_font_color"], weight=700)
        ),
        xaxis=dict(
            title=dict(
                text=f"<b>{y_col} {x_axis_display_name}</b>",
                font=dict(size=16, color=theme_colors["legend_font_color"])
            ),
            tickfont=dict(size=12, color="#94A3B8"),
            gridcolor="rgba(245, 246, 247, 0.1)",
            zerolinecolor="rgba(245, 246, 247, 0.2)"
        ),
        yaxis=dict(
            title=dict(
                text=f"<b>{y_col} Score</b>",
                font=dict(size=16, color=theme_colors["legend_font_color"])
            ),
            tickfont=dict(size=12, color="#94A3B8"),
            gridcolor="rgba(245, 246, 247, 0.1)",
            zerolinecolor="rgba(245, 246, 247, 0.2)"
        ),
        paper_bgcolor=theme_colors["paper_bg"],
        plot_bgcolor=theme_colors["plot_bg"],
        height=900,
        width=1450,
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1,
            xanchor="center",
            x=0.5,
            font=dict(size=12, family="Verdana, sans-serif", color=theme_colors["legend_font_color"]),
            bgcolor=theme_colors["legend_bg"],
            bordercolor='rgba(245, 246, 247, 0.2)',
            borderwidth=1
        ),
        margin=dict(t=100, b=80, l=80, r=80)
    )
    return fig
def create_language_radar_chart(
    df: pd.DataFrame,
    metric_type: str,
    selected_models: Optional[List[str]] = None,
    max_models: int = 5,
    theme: str = "light",
    mode: str = "open"
) -> Figure:
    """
    Create a radar chart of per-language scores for the selected models.

    The angular axis has one spoke per language column (KO, EN, JA, ...). The
    radial axis is fixed to 0..100. Missing, blank or non-numeric cells are
    plotted as 0.

    Args:
        df: DataFrame with a 'Model Name' column and one column per language.
        metric_type: Sort label (mapped through SORT_COLUMN_MAP). It is used
            only to pick the default models when ``selected_models`` is empty.
        selected_models: Model names to draw; truncated to ``max_models``.
        max_models: Maximum number of models drawn.
        theme: "light" or "dark" (default: "light").
        mode: Not used by this function; kept for interface compatibility.

    Returns:
        A plotly Figure with one Scatterpolar trace per found model.
    """
    language_domains = ['KO', 'EN', 'JA', 'ZH', 'PL', 'DE', 'PT', 'ES', 'FR', 'IT', 'RU', 'VI']
    if selected_models is None or len(selected_models) == 0:
        actual_metric_type = SORT_COLUMN_MAP.get(metric_type, metric_type)
        if actual_metric_type in df.columns:
            selected_models = df.nlargest(max_models, actual_metric_type)['Model Name'].tolist()
        else:
            selected_models = df.head(max_models)['Model Name'].tolist()
    selected_models = selected_models[:max_models]
    harmonious_palette_light = [
        {'fill': 'rgba(79,143,198,0.25)', 'line': '#4F8FC6', 'name': 'BlueGray'},
        {'fill': 'rgba(109,213,237,0.25)', 'line': '#6DD5ED', 'name': 'SkyBlue'},
        {'fill': 'rgba(162,89,247,0.25)', 'line': '#A259F7', 'name': 'Violet'},
        {'fill': 'rgba(67,233,123,0.25)', 'line': '#43E97B', 'name': 'Mint'},
        {'fill': 'rgba(255,215,0,0.20)', 'line': '#FFD700', 'name': 'Gold'}
    ]
    harmonious_palette_dark = [
        {'fill': 'rgba(144,202,249,0.25)', 'line': '#90CAF9', 'name': 'LightBlue'},
        {'fill': 'rgba(128,203,196,0.25)', 'line': '#80CBC4', 'name': 'Mint'},
        {'fill': 'rgba(179,157,219,0.25)', 'line': '#B39DDB', 'name': 'Lavender'},
        {'fill': 'rgba(244,143,177,0.25)', 'line': '#F48FB1', 'name': 'Pink'},
        {'fill': 'rgba(255,213,79,0.20)', 'line': '#FFD54F', 'name': 'Gold'}
    ]
    palette = harmonious_palette_light if theme == "light" else harmonious_palette_dark
    fig = go.Figure()
    for idx, model_name in enumerate(selected_models):
        model_data = df[df['Model Name'] == model_name]
        if model_data.empty:
            continue
        model_row = model_data.iloc[0]
        values = []
        for lang in language_domains:
            raw = model_row[lang] if lang in model_row else 0
            # Blank ('') or non-numeric cells used to raise ValueError in
            # float(); treat them, like NaN/None, as 0.
            try:
                val = float(raw)
            except (TypeError, ValueError):
                val = 0.0
            values.append(0.0 if pd.isna(val) else val)
        # Repeat the first point so the polygon closes.
        values_plot = values + [values[0]]
        domains_plot = language_domains + [language_domains[0]]
        colors = palette[idx % len(palette)]
        fig.add_trace(
            go.Scatterpolar(
                r=values_plot,
                theta=domains_plot,
                fill='toself',
                fillcolor=colors['fill'],
                line=dict(
                    color=colors['line'],
                    width=3,
                    shape='spline',
                    smoothing=0.5
                ),
                marker=dict(
                    size=10,
                    color=colors['line'],
                    symbol='circle',
                    line=dict(width=2, color='#01091A' if theme == "light" else '#e3e6f3')
                ),
                name=get_display_model_name(model_name),
                mode="lines+markers",
                hovertemplate="<b>%{fullData.name}</b><br>" +
                              "<span style='color: #94A3B8'>%{theta}</span><br>" +
                              "<b style='font-size: 12px'>%{r:.3f}</b><br>" +
                              "<extra></extra>",
                hoverlabel=dict(
                    bgcolor="rgba(1, 9, 26, 0.95)" if theme == "dark" else "rgba(227,230,243,0.95)",
                    bordercolor=colors['line'],
                    font=dict(color="#F5F6F7" if theme == "dark" else "#23244a", size=12, family="Verdana, sans-serif")
                )
            )
        )
    max_range = 100.0
    tick_vals = [i * max_range / 5 for i in range(6)]
    tick_text = [f"{val:.2f}" for val in tick_vals]
    theme_colors = get_theme_colors(theme)
    fig.update_layout(
        polar=dict(
            bgcolor=theme_colors["plot_bg"],
            domain=dict(x=[0, 1], y=[0, 1]),
            radialaxis=dict(
                visible=True,
                range=[0, max_range],
                showline=True,
                linewidth=2,
                linecolor='rgba(245, 246, 247, 0.2)',
                gridcolor='rgba(245, 246, 247, 0.1)',
                gridwidth=1,
                tickvals=tick_vals,
                ticktext=tick_text,
                tickfont=dict(
                    size=11,
                    color='#94A3B8',
                    family="'Geist Mono', monospace"
                ),
                tickangle=0
            ),
            # Spoke labels are the language codes themselves. The stray
            # category-name ticktext (10 labels for 12 language spokes) that
            # was copied from the category chart has been removed.
            angularaxis=dict(
                showline=True,
                linewidth=2,
                linecolor='rgba(245, 246, 247, 0.2)',
                gridcolor='rgba(245, 246, 247, 0.08)',
                tickfont=dict(
                    size=14,
                    family="Verdana, sans-serif",
                    color=theme_colors["legend_font_color"],
                    weight=600
                ),
                rotation=90,
                direction="clockwise",
            ),
        ),
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-0.15,
            xanchor="center",
            x=0.5,
            font=dict(
                size=12,
                family="Verdana, sans-serif",
                color=theme_colors["legend_font_color"]
            ),
            bgcolor=theme_colors["legend_bg"],
            bordercolor='rgba(245, 246, 247, 0.2)',
            borderwidth=1,
            itemsizing='constant',
            itemwidth=30
        ),
        title=dict(
            text="<b>Language Performance</b>",
            x=0.5,
            y=0.97,
            font=dict(
                size=22,
                family="Verdana, sans-serif",
                color=theme_colors["legend_font_color"],
                weight=700
            ),
        ),
        paper_bgcolor=theme_colors["paper_bg"],
        plot_bgcolor=theme_colors["plot_bg"],
        height=900,
        width=1450,
        margin=dict(t=100, b=80, l=80, r=80),
        annotations=[
            dict(
                text="TRUEBench",
                xref="paper", yref="paper",
                x=0.98, y=0.02,
                xanchor='right', yanchor='bottom',
                font=dict(size=10, color=theme_colors["annotation_color"]),
                showarrow=False
            )
        ]
    )
    return fig
def load_leaderboard_data(data_prefix: str = "open/") -> pd.DataFrame:
    """Return the processed per-category leaderboard DataFrame for ``data_prefix``."""
    # Imported lazily inside the function (presumably to keep module import
    # cheap or avoid an import cycle — NOTE(review): confirm).
    from src.data_loader import get_category_dataframe as _load_category_frame
    category_frame = _load_category_frame(data_prefix=data_prefix, processed=True)
    return category_frame
def load_leaderboard_language_data(data_prefix: str = "open/") -> pd.DataFrame:
    """Return the processed per-language leaderboard DataFrame for ``data_prefix``."""
    # Imported lazily inside the function (presumably to keep module import
    # cheap or avoid an import cycle — NOTE(review): confirm).
    from src.data_loader import get_language_dataframe as _load_language_frame
    language_frame = _load_language_frame(data_prefix=data_prefix, processed=True)
    return language_frame
def _speed_bar_placeholder(message: str, title: str) -> Figure:
    """Return a 1445x600 placeholder figure showing ``message`` under ``title``."""
    fig = go.Figure()
    fig.add_annotation(
        text=message,
        xref="paper", yref="paper",
        x=0.5, y=0.5,
        xanchor='center', yanchor='middle',
        font=dict(size=18, color="#94A3B8", family="Verdana, sans-serif"),
        showarrow=False
    )
    fig.update_layout(
        paper_bgcolor="#01091A",
        plot_bgcolor="rgba(245, 246, 247, 0.02)",
        height=600,
        width=1445,
        margin=dict(t=100, b=80, l=80, r=80),
        title=dict(
            text=title,
            x=0.5,
            y=0.97,
            font=dict(size=22, family="Verdana, sans-serif", color="#F5F6F7", weight=700)
        )
    )
    return fig


def create_speed_med_bar_plot(
    leaderboard_df: pd.DataFrame,
    time_data: dict,
    min_size: float = 0,
    max_size: float = 1000,
    min_score: float = 0,
    max_score: float = 100,
    category: str = "Overall",
    theme: str = "light",
    x_axis_sort_by: str = "Speed per GPU",
    mode: str = "open"
) -> Figure:
    """
    Create a bar plot of median Speed per GPU, with the Overall score overlaid.

    Each bar is one model: its median Speed for ``category`` divided by its
    GPU count (0 when the GPU count is missing or not positive). Orange
    markers on a secondary y axis show each model's 'Overall' score.

    Parameters:
        leaderboard_df: DataFrame with model scores. Must include 'Model Name',
            ``category``, 'Overall' and 'Parameter Size (B)'.
        time_data: Mapping ``{model: {category: {'Speed': {'Med': value}},
            'NUM_GPUS': n}}``. Models without a Speed median are skipped.
        min_size: Minimum parameter size (inclusive). Models whose size is
            unknown are always kept.
        max_size: Maximum parameter size (inclusive).
        min_score: Minimum 'Overall' score (inclusive).
        max_score: Maximum 'Overall' score (inclusive).
        category: Category whose Speed median is plotted. The score window is
            always applied to 'Overall', not to this category.
        theme: "light" or "dark".
        x_axis_sort_by: "Overall Score" sorts bars by Overall (descending).
            "Speed per GPU" (default), "Speed" or any other value sorts by
            Speed per GPU (descending).
        mode: Not used by this function; kept for interface compatibility.

    Returns:
        A plotly Figure, or a placeholder figure when columns or data are
        missing.
    """
    # Defensive: 'Overall' and 'Parameter Size (B)' are read below and
    # previously raised KeyError when absent.
    required = ("Model Name", category, "Overall", "Parameter Size (B)")
    if any(col not in leaderboard_df.columns for col in required):
        return _speed_bar_placeholder(
            f"Leaderboard missing required columns for category '{category}'.",
            "<b>Speed per GPU Bar Plot</b>",
        )

    df = leaderboard_df.copy()
    # Drop rows whose 'Speed' cell is an empty string (when that column exists).
    if "Speed" in df.columns:
        df = df[df["Speed"] != ""].copy()
    df["Parameter Size (B)"] = pd.to_numeric(df["Parameter Size (B)"], errors="coerce")
    df["Overall"] = pd.to_numeric(df["Overall"], errors="coerce")
    size = df["Parameter Size (B)"]
    # Unknown sizes pass the size filter; NaN Overall scores never pass.
    in_size = size.isnull() | ((size >= min_size) & (size <= max_size))
    in_score = (df["Overall"] >= min_score) & (df["Overall"] <= max_score)
    filtered = df[in_size & in_score]

    # Collect one record per model that has a Speed median for the category.
    rows = []
    for _, row in filtered.iterrows():
        model = row["Model Name"]
        # Best-effort lookups: a missing or non-dict level in time_data means
        # "no data" rather than an error.
        try:
            speed_med = time_data.get(model, {}).get(category, {}).get("Speed", {}).get("Med", None)
        except (AttributeError, TypeError):
            speed_med = None
        if speed_med is None:
            continue
        try:
            num_gpus = time_data.get(model, {}).get("NUM_GPUS", 0)
        except (AttributeError, TypeError):
            num_gpus = None
        rows.append({
            "Model Name": model,
            "Display Name": get_display_model_name(model),
            "Speed per GPU": (speed_med / num_gpus) if (num_gpus is not None and num_gpus > 0) else 0,
            "Speed": speed_med,
            # Take the score from the row being iterated instead of searching
            # the whole frame again (O(n^2), wrong row on duplicate names).
            "Overall": float(row["Overall"]),
            "GPU": num_gpus
        })

    if not rows:
        return _speed_bar_placeholder(
            f"No Speed data available for models in selected score range ({category}).",
            "<b>Speed Bar Plot</b>",
        )

    if x_axis_sort_by == "Overall Score":
        rows.sort(key=lambda r: r["Overall"] if r["Overall"] is not None else float('-inf'), reverse=True)
    else:
        rows.sort(key=lambda r: r["Speed per GPU"], reverse=True)

    x_labels = [r["Display Name"] for r in rows]
    y_values = [r["Speed per GPU"] for r in rows]
    overall_values = [r["Overall"] for r in rows]
    # One [speed, gpu] pair per bar, read by the hover template.
    customdata = [[r["Speed"], r["GPU"]] for r in rows]

    theme_colors = get_theme_colors(theme)
    fig = go.Figure()
    # Cycle through a vivid blue/sky-blue palette for the bars.
    vivid_blues = [
        "#0099FF", "#00BFFF", "#1EC8FF", "#4FC3F7", "#00CFFF", "#00B2FF", "#00AEEF", "#00C6FB", "#00E5FF", "#00B8D9"
    ]
    bar_colors = [vivid_blues[i % len(vivid_blues)] for i in range(len(x_labels))]
    fig.add_trace(go.Bar(
        x=x_labels,
        y=y_values,
        name="Speed per GPU",
        marker=dict(
            color=bar_colors,
            line=dict(color="#23244a", width=1.5)
        ),
        text=[f"{v:,.1f}" for v in y_values],
        textposition="auto",
        customdata=customdata,
        hovertemplate=(
            "<b>%{x}</b><br>" +
            "Speed per GPU: %{y:,.1f}<br>" +
            "Speed: %{customdata[0]:,.1f}<br>" +
            "GPU: %{customdata[1]:d}<extra></extra>"
        ),
        yaxis="y1"
    ))
    # Overall score markers on the secondary (right) y axis, always shown.
    fig.add_trace(go.Scatter(
        x=x_labels,
        y=overall_values,
        name="Overall Score",
        mode="markers",
        line=dict(color="#FF8800", width=3),
        marker=dict(color="#FF8800", size=10, symbol="triangle-down"),
        yaxis="y2",
        hovertemplate="<b>%{x}</b><br>Overall: %{y:.2f}<extra></extra>"
    ))
    fig.update_layout(
        title=dict(
            text="<b>Median Speed per GPU and Overall Score</b>",
            x=0.5,
            y=0.97,
            font=dict(size=22, family="Verdana, sans-serif", color="#F5F6F7", weight=700)
        ),
        xaxis=dict(
            title=dict(
                text="<b>Model</b>",
                font=dict(size=16, color=theme_colors["legend_font_color"])
            ),
            # White tick labels for the model names.
            tickfont=dict(size=12, color="#F5F6F7"),
            tickangle=45,
            gridcolor="rgba(245, 246, 247, 0.1)",
            zerolinecolor="rgba(245, 246, 247, 0.2)"
        ),
        yaxis=dict(
            title=dict(
                text="<b>Speed per GPU</b>",
                # Blue to match the bars.
                font=dict(size=16, color="#00BFFF")
            ),
            tickfont=dict(size=12, color="#00BFFF"),
            gridcolor="rgba(245, 246, 247, 0.1)",
            zerolinecolor="rgba(245, 246, 247, 0.2)"
        ),
        yaxis2=dict(
            title={
                "text": "<b>Overall Score</b>",
                "font": dict(size=16, color="#FF8800")
            },
            overlaying="y",
            side="right",
            showgrid=False,
            tickfont=dict(size=12, color="#FF8800")
        ),
        paper_bgcolor=theme_colors["paper_bg"],
        plot_bgcolor=theme_colors["plot_bg"],
        height=600,
        width=1445,
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="center",
            x=0.5,
            font=dict(size=12, family="Verdana, sans-serif", color=theme_colors["legend_font_color"]),
            bgcolor=theme_colors["legend_bg"],
            bordercolor='rgba(245, 246, 247, 0.2)',
            borderwidth=1
        ),
        margin=dict(t=100, b=120, l=80, r=80)
    )
    return fig
def create_domain_radar_chart(
    df: pd.DataFrame,
    metric_type: str,
    selected_models: Optional[List[str]] = None,
    max_models: int = 5,
    theme: str = "light",
    mode: str = "open"
) -> Figure:
    """
    Create a radar chart of per-category scores for the selected models.

    The spokes are the ten category columns ('Content Generation' ...
    'Multi-Turn'), labelled with the column names. The radial axis is fixed
    to 0..100. Missing, blank or non-numeric cells are plotted as 0.

    Args:
        df: DataFrame with a 'Model Name' column and the category columns.
        metric_type: Sort label, mapped through SORT_COLUMN_MAP. Unsupported
            metrics yield a placeholder chart. It is also used to pick the
            default models when ``selected_models`` is empty.
        selected_models: Model names to draw; truncated to ``max_models``.
        max_models: Maximum number of models drawn.
        theme: "light" or "dark" (default: "light").
        mode: Not used by this function; kept for interface compatibility.

    Returns:
        A plotly Figure with one Scatterpolar trace per found model, or a
        placeholder figure from ``create_empty_radar_chart``.
    """
    actual_metric_type = SORT_COLUMN_MAP.get(metric_type, metric_type)
    # Every supported metric uses the same category spokes.
    supported_metrics = ('Avg AC', 'Avg TSQ', 'Avg Total Cost', 'Avg Session Duration', 'Avg Turns')
    domains = [
        'Content Generation', 'Editing', 'Data Analysis', 'Reasoning', 'Hallucination',
        'Safety', 'Repetition', 'Summarization', 'Translation', 'Multi-Turn'
    ]
    if actual_metric_type not in supported_metrics:
        return create_empty_radar_chart(f"Domain breakdown not available for {metric_type}")
    if selected_models is None or len(selected_models) == 0:
        if actual_metric_type in df.columns:
            selected_models = df.nlargest(max_models, actual_metric_type)['Model Name'].tolist()
        else:
            selected_models = df.head(max_models)['Model Name'].tolist()
    selected_models = selected_models[:max_models]
    harmonious_palette_light = [
        {'fill': 'rgba(79,143,198,0.25)', 'line': '#4F8FC6', 'name': 'BlueGray'},
        {'fill': 'rgba(109,213,237,0.25)', 'line': '#6DD5ED', 'name': 'SkyBlue'},
        {'fill': 'rgba(162,89,247,0.25)', 'line': '#A259F7', 'name': 'Violet'},
        {'fill': 'rgba(67,233,123,0.25)', 'line': '#43E97B', 'name': 'Mint'},
        {'fill': 'rgba(255,215,0,0.20)', 'line': '#FFD700', 'name': 'Gold'}
    ]
    harmonious_palette_dark = [
        {'fill': 'rgba(144,202,249,0.25)', 'line': '#90CAF9', 'name': 'LightBlue'},
        {'fill': 'rgba(128,203,196,0.25)', 'line': '#80CBC4', 'name': 'Mint'},
        {'fill': 'rgba(179,157,219,0.25)', 'line': '#B39DDB', 'name': 'Lavender'},
        {'fill': 'rgba(244,143,177,0.25)', 'line': '#F48FB1', 'name': 'Pink'},
        {'fill': 'rgba(255,213,79,0.20)', 'line': '#FFD54F', 'name': 'Gold'}
    ]
    palette = harmonious_palette_light if theme == "light" else harmonious_palette_dark
    fig = go.Figure()
    for idx, model_name in enumerate(selected_models):
        model_data = df[df['Model Name'] == model_name]
        if model_data.empty:
            continue
        model_row = model_data.iloc[0]
        values = []
        for domain in domains:
            # model_row's index is df.columns, so this also covers a missing column.
            raw = model_row[domain] if domain in model_row else 0
            # Blank ('') or non-numeric cells used to raise ValueError in
            # float(); treat them, like NaN/None, as 0.
            try:
                val = float(raw)
            except (TypeError, ValueError):
                val = 0.0
            values.append(0.0 if pd.isna(val) else val)
        # Repeat the first point so the polygon closes.
        values_plot = values + [values[0]]
        domains_plot = domains + [domains[0]]
        colors = palette[idx % len(palette)]
        fig.add_trace(
            go.Scatterpolar(
                r=values_plot,
                theta=domains_plot,
                fill='toself',
                fillcolor=colors['fill'],
                line=dict(
                    color=colors['line'],
                    width=3,
                    shape='spline',
                    smoothing=0.5
                ),
                marker=dict(
                    size=10,
                    color=colors['line'],
                    symbol='circle',
                    line=dict(width=2, color='#01091A' if theme == "light" else '#e3e6f3')
                ),
                name=get_display_model_name(model_name),
                mode="lines+markers",
                hovertemplate="<b>%{fullData.name}</b><br>" +
                              "<span style='color: #94A3B8'>%{theta}</span><br>" +
                              "<b style='font-size: 12px'>%{r:.3f}</b><br>" +
                              "<extra></extra>",
                hoverlabel=dict(
                    bgcolor="rgba(1, 9, 26, 0.95)" if theme == "dark" else "rgba(227,230,243,0.95)",
                    bordercolor=colors['line'],
                    font=dict(color="#F5F6F7" if theme == "dark" else "#23244a", size=12, family="Verdana, sans-serif")
                )
            )
        )
    max_range = 100.0
    tick_vals = [i * max_range / 5 for i in range(6)]
    tick_text = [f"{val:.2f}" for val in tick_vals]
    theme_colors = get_theme_colors(theme)
    fig.update_layout(
        polar=dict(
            bgcolor=theme_colors["plot_bg"],
            radialaxis=dict(
                visible=True,
                range=[0, max_range],
                showline=True,
                linewidth=2,
                linecolor='rgba(245, 246, 247, 0.2)',
                gridcolor='rgba(245, 246, 247, 0.1)',
                gridwidth=1,
                tickvals=tick_vals,
                ticktext=tick_text,
                tickfont=dict(
                    size=11,
                    color='#94A3B8',
                    family="'Geist Mono', monospace"
                ),
                tickangle=0
            ),
            angularaxis=dict(
                showline=True,
                linewidth=2,
                linecolor='rgba(245, 246, 247, 0.2)',
                gridcolor='rgba(245, 246, 247, 0.08)',
                tickfont=dict(
                    size=14,
                    family="Verdana, sans-serif",
                    color=theme_colors["legend_font_color"],
                    weight=600
                ),
                rotation=90,
                direction="clockwise",
            ),
        ),
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-0.15,
            xanchor="center",
            x=0.5,
            font=dict(
                size=12,
                family="Verdana, sans-serif",
                color=theme_colors["legend_font_color"]
            ),
            bgcolor=theme_colors["legend_bg"],
            bordercolor='rgba(245, 246, 247, 0.2)',
            borderwidth=1,
            itemsizing='constant',
            itemwidth=30
        ),
        title=dict(
            text="<b>Category Performance</b>",
            x=0.5,
            y=0.97,
            font=dict(
                size=22,
                family="Verdana, sans-serif",
                color=theme_colors["legend_font_color"],
                weight=700
            ),
        ),
        paper_bgcolor=theme_colors["paper_bg"],
        plot_bgcolor=theme_colors["plot_bg"],
        height=900,
        width=1450,
        margin=dict(t=100, b=80, l=80, r=80),
        annotations=[
            dict(
                text="TRUEBench",
                xref="paper", yref="paper",
                x=0.98, y=0.02,
                xanchor='right', yanchor='bottom',
                font=dict(size=10, color=theme_colors["annotation_color"]),
                showarrow=False
            )
        ]
    )
    return fig