Spaces:
Running
Running
# -*- coding: utf-8 -*-
# @Date : 2025/2/5 16:26
# @Author : q275343119
# @File : data_page.py
# @Description:
| import io | |
| from st_aggrid import AgGrid, JsCode, ColumnsAutoSizeMode | |
| import streamlit as st | |
| import streamlit.components.v1 as components | |
| from utils.st_copy_to_clipboard import st_copy_to_clipboard | |
| from streamlit_theme import st_theme | |
| from app.backend.app_init_func import LEADERBOARD_MAP | |
| from app.backend.constant import LEADERBOARD_ICON_MAP, BASE_URL | |
| from app.backend.json_util import compress_msgpack, decompress_msgpack | |
# Model metadata columns that are always selected from the results frame,
# ahead of the per-page score columns.
COLUMNS = [
    "model_name",
    "vendor",
    "embd_dtype",
    "embd_dim",
    "num_params",
    "max_tokens",
    "similarity",
    "query_instruct",
    "corpus_instruct",
    "reference",
]

# Font settings handed to every grid column as ``headerStyle`` / ``cellStyle``.
HEADER_STYLE = {"fontSize": "18px"}
CELL_STYLE = {"fontSize": "18px"}
def is_section(group_name):
    """Tell whether ``group_name`` is the name of a leaderboard section.

    Every value of ``LEADERBOARD_MAP`` is indexed as ``value[0][0]`` to obtain
    a leaderboard name; the group counts as a section as soon as one of those
    names equals ``group_name``.

    Args:
        group_name: Name of the group being rendered.

    Returns:
        bool: True if any leaderboard name matches, False otherwise.
    """
    return any(group_name == entry[0][0] for entry in LEADERBOARD_MAP.values())
def get_closed_dataset():
    """Collect the dataset names of all results flagged as closed.

    Reads ``results`` from the ``data_engine`` object kept in
    ``st.session_state`` and keeps the ``dataset_name`` of each entry whose
    ``is_closed`` value is truthy.

    Returns:
        list: One ``dataset_name`` per closed result, in result order. Names
        are not de-duplicated, and a missing ``dataset_name`` yields ``None``.
    """
    engine = st.session_state["data_engine"]
    return [
        entry.get("dataset_name")
        for entry in engine.results
        if entry.get("is_closed")
    ]
def convert_df_to_csv(df):
    """Serialize ``df`` to CSV text without the index column.

    Args:
        df: Object exposing ``to_csv(buffer, index=False)``.
            # presumably a pandas DataFrame -- confirm against the data engine

    Returns:
        str: Everything ``df.to_csv`` wrote to the in-memory buffer.
    """
    with io.StringIO() as buffer:
        df.to_csv(buffer, index=False)
        return buffer.getvalue()
def get_column_state():
    """Copy grid and sidebar settings from the URL query string into the session.

    * ``grid_state``: when present and non-empty, it is decoded with
      ``decompress_msgpack`` and stored as ``st.session_state.grid_state``.
    * ``sider_bar_hidden``: defaults to ``"False"``; when its value equals
      ``"false"`` (case-insensitively) ``st.session_state.sider_bar_hidden``
      is set to ``False``. Any other value leaves the session flag untouched.

    Returns:
        None
    """
    params = st.query_params
    encoded_state = params.get("grid_state", None)
    hide_flag = params.get("sider_bar_hidden", "False")

    if encoded_state:
        # NOTE(review): the value comes straight from the URL; anything raised
        # by decompress_msgpack for a malformed value propagates to the caller.
        st.session_state.grid_state = decompress_msgpack(encoded_state)

    if hide_flag.upper() == "FALSE":
        st.session_state.sider_bar_hidden = False
# CSS injected while the sidebar is hidden: removes the sidebar, its
# navigation and the header toggle button, and centers the page title.
_HIDE_SIDEBAR_CSS = """
<style>
[data-testid="stSidebar"] {
display: none !important;
}
[data-testid="stSidebarNav"] {
display: none !important;
}
[data-testid="stBaseButton-headerNoPadding"] {
display: none !important;
}
h1#retrieval-embedding-benchmark-rteb {
text-align: center;
}
</style>
"""

# Theme colour variables plus AG Grid, link, download-button and toast styling.
_THEME_CSS = """
<style>
:root {
--theme-color: rgb(129, 150, 64);
--theme-color-light: rgba(129, 150, 64, 0.2);
}
/* AG Grid specific overrides */
.ag-theme-alpine {
--ag-selected-row-background-color: var(--theme-color-light) !important;
--ag-row-hover-color: var(--theme-color-light) !important;
--ag-selected-tab-color: var(--theme-color) !important;
--ag-range-selection-border-color: var(--theme-color) !important;
--ag-range-selection-background-color: var(--theme-color-light) !important;
}
.ag-row-hover {
background-color: var(--theme-color-light) !important;
}
.ag-row-selected {
background-color: var(--theme-color-light) !important;
}
.ag-row-focus {
background-color: var(--theme-color-light) !important;
}
.ag-cell-focus {
border-color: var(--theme-color) !important;
}
/* Keep existing styles */
.center-text {
text-align: center;
color: var(--theme-color);
}
.center-image {
display: block;
margin-left: auto;
margin-right: auto;
}
h2 {
color: var(--theme-color) !important;
}
.ag-header-cell {
background-color: var(--theme-color) !important;
color: white !important;
}
a {
color: var(--theme-color) !important;
}
a:hover {
color: rgba(129, 150, 64, 0.8) !important;
}
/* Download Button */
button[data-testid="stBaseButton-secondary"] {
float: right;
}
/* Toast On The Top*/
div[data-testid="stToastContainer"] {
position: fixed !important;
z-index: 2147483647 !important;
}
</style>
"""

# Tooltip for the model-name cell: shows the raw cell value.
_VALUE_TOOLTIP_JS = """function(p) {return p.value}"""

# Cell renderer for the model name: wraps it in a link to the row's
# ``reference`` when one is set, otherwise shows the bare value.
_MODEL_LINK_RENDERER_JS = """class CustomHTML {
init(params) {
const link = params.data.reference;
this.eGui = document.createElement('div');
this.eGui.innerHTML = link ?
`<a href="${link}" class="a-cell" target="_blank">${params.value} </a>` :
params.value;
}
getGui() {
return this.eGui;
}
}"""

# Formats parameter counts with a K / M / B suffix and two decimals.
_NUM_PARAMS_FORMATTER_JS = """function(params) {
const num = params.value;
if (num >= 1e9) return (num / 1e9).toFixed(2) + "B";
if (num >= 1e6) return (num / 1e6).toFixed(2) + "M";
if (num >= 1e3) return (num / 1e3).toFixed(2) + "K";
return num;
}"""

# Header renderer for score columns: "...Average" columns are plain text, any
# other column links to its dataset page on the Hugging Face Hub.
_DATASET_HEADER_JS = """
class DatasetHeaderRenderer {
init(params) {
this.eGui = document.createElement('div');
const columnName = params.displayName;
const fieldName = params.column.colId;
if (fieldName.includes('Average')) {
// For group columns (like "Code Average", "English Average"), display as plain text
this.eGui.textContent = columnName;
} else {
// For dataset columns, create clickable link to HuggingFace dataset
const link = document.createElement('a');
link.href = 'https://huggingface.co/datasets/embedding-benchmark/' + fieldName;
link.target = '_blank';
link.style.color = 'white';
link.style.textDecoration = 'underline';
link.style.cursor = 'pointer';
link.textContent = columnName;
// Prevent event bubbling to avoid sorting trigger
link.addEventListener('click', function(e) {
e.stopPropagation();
});
this.eGui.appendChild(link);
}
}
getGui() {
return this.eGui;
}
}
"""

# Extra CSS rules handed to AgGrid through ``custom_css``.
_GRID_CUSTOM_CSS = {
    # Model Name Cell
    ".a-cell": {
        "display": "inline-block",
        "white-space": "nowrap",
        "overflow": "hidden",
        "text-overflow": "ellipsis",
        "width": "100%",
        "min-width": "0",
    },
    # Header
    ".multi-line-header": {
        "text-overflow": "clip",
        "overflow": "visible",
        "white-space": "normal",
        "height": "auto",
        "font-family": 'Arial',
        "font-size": "14px",
        "font-weight": "bold",
        "padding": "10px",
        "text-align": "left",
    },
    # Filter Options and Input
    ".ag-theme-streamlit .ag-popup": {
        "font-family": 'Arial',
        "font-size": "14px",
    },
    ".ag-picker-field-display": {
        "font-family": 'Arial',
        "font-size": "14px",
    },
    ".ag-input-field-input .ag-text-field-input": {
        "font-family": 'Arial',
        "font-size": "14px",
    },
}


def _plain_column(header_name, field):
    """Return a column definition that only sets header text, field and fonts."""
    return {
        'headerName': header_name,
        'field': field,
        'headerStyle': HEADER_STYLE,
        'cellStyle': CELL_STYLE,
    }


def _score_column(column):
    """Return the column definition for one per-dataset / per-group score column."""
    # "<Group> Average" columns are labelled with the group name only; every
    # other column keeps its own name as header and tooltip.
    if "Average" in column:
        label = column.replace("Average", "").strip().capitalize()
    else:
        label = column
    definition = _plain_column(label, column)
    definition["headerTooltip"] = label
    definition['headerComponent'] = JsCode(_DATASET_HEADER_JS)
    return definition


def _build_grid_options(avg_column, score_columns, grid_state):
    """Assemble the AG Grid options: fixed columns first, then ``score_columns``."""
    model_column = {
        'headerName': 'Model Name',
        'field': 'model_name',
        'pinned': 'left',
        'sortable': False,
        'headerStyle': HEADER_STYLE,
        'cellStyle': CELL_STYLE,
        "tooltipValueGetter": JsCode(_VALUE_TOOLTIP_JS),
        "width": 250,
        'cellRenderer': JsCode(_MODEL_LINK_RENDERER_JS),
        'suppressSizeToFit': True,
    }
    num_params_column = {
        'headerName': 'Number of Parameters',
        'field': 'num_params',
        'cellDataType': 'number',
        "colId": "num_params",
        'headerStyle': HEADER_STYLE,
        'cellStyle': CELL_STYLE,
        'valueFormatter': JsCode(_NUM_PARAMS_FORMATTER_JS),
        "width": 120,
    }
    return {
        'columnDefs': [
            model_column,
            _plain_column("Vendor", 'vendor'),
            _plain_column("Overall Score", avg_column),
            _plain_column('Open Average', 'Open average'),
            _plain_column('Closed Average', 'Closed average'),
            _plain_column('Embd Dtype', 'embd_dtype'),
            _plain_column('Embd Dim', 'embd_dim'),
            num_params_column,
            _plain_column('Context Length', 'max_tokens'),
            *[_score_column(column) for column in score_columns],
        ],
        'defaultColDef': {
            'filter': True,
            'sortable': True,
            'resizable': True,
            'headerClass': "multi-line-header",
            'autoHeaderHeight': True,
            'width': 105,
        },
        "autoSizeStrategy": {
            "type": 'fitCellContents',
            "colIds": list(score_columns),
        },
        "tooltipShowDelay": 500,
        "initialState": grid_state,
    }


def _render_share_link(grid, group_name, section):
    """Write a shareable URL for the page and a copy-to-clipboard button for it."""
    base_url = BASE_URL.replace("_", "-")
    state = grid.grid_state
    if state:
        encoded = compress_msgpack(state)
        if section:
            share_link = f'{base_url}?grid_state={encoded}'
        else:
            share_link = f'{base_url}{group_name}/?grid_state={encoded}'
    else:
        # NOTE(review): unlike the branch above, this appends ``group_name``
        # for section pages too -- confirm that this is the intended URL.
        share_link = f'{base_url}{group_name}'
    st.write(share_link)

    # Fall back to the light theme when the theme component returns nothing.
    theme = st_theme()
    theme = theme.get("base") if theme else "light"
    st_copy_to_clipboard(share_link, before_copy_label='📋Push to copy', after_copy_label='✅Text copied!',
                         theme=theme)


def render_page(group_name):
    """Render the leaderboard page for ``group_name``.

    The page consists of the title, a CSV download of the full results frame,
    an AG Grid table sorted by the page's average score in descending order,
    and a "Share this page" button that prints a link carrying the current
    grid state.

    Args:
        group_name: Name of the group to show. When ``is_section`` accepts it,
            the page is labelled "<Group> Leaderboard" and shows the
            "...Average" columns; otherwise it shows the group's dataset
            columns.

    Raises:
        IndexError: For a section page whose frame has no column containing
            "Average".
        KeyError: When a selected column is missing from the results frame.
    """
    # The sidebar starts hidden; get_column_state() un-hides it when the URL
    # says so and copies a ``grid_state`` query parameter into the session.
    st.session_state.sider_bar_hidden = True
    get_column_state()
    # Read the grid state only after the URL has been processed, so a shared
    # link takes effect on the first run instead of only after a rerun.
    grid_state = st.session_state.get("grid_state", {})

    if st.session_state.sider_bar_hidden:
        st.markdown(_HIDE_SIDEBAR_CSS, unsafe_allow_html=True)

    st.title("Retrieval Embedding Benchmark (RTEB)")
    # Add theme color and grid styles
    st.markdown(_THEME_CSS, unsafe_allow_html=True)

    # Evaluated once: every later branch depends on the same answer.
    section = is_section(group_name)
    group_label = group_name.capitalize()
    page_label = f"{group_label} Leaderboard" if section else group_label

    # title
    st.markdown(
        f'<h2 class="center-text">{LEADERBOARD_ICON_MAP.get(page_label, "")} {page_label}</h2>',
        unsafe_allow_html=True,
    )

    data_engine = st.session_state["data_engine"]
    df = data_engine.jsons_to_df().copy()

    # The download holds the frame as loaded, before the derived
    # "Closed average" / "Open average" columns are added below.
    st.download_button(
        label="Download CSV",
        data=convert_df_to_csv(df),
        file_name=f"{page_label}.csv",
        mime="text/csv",
        icon=":material/download:",
    )

    # get columns
    column_list = []
    avg_column = None
    if section:
        # Columns starting with "Average" are moved to the front; any other
        # column containing "Average" keeps its frame order behind them.
        avg_columns = []
        for column in df.columns:
            if column.startswith("Average"):
                avg_columns.insert(0, column)
            elif "Average" in column:
                avg_columns.append(column)
        avg_column = avg_columns[0]
        column_list.extend(avg_columns)
    else:
        # The last column named "<Group> ..." is the page's average column.
        prefix = group_label + " "
        for column in df.columns:
            if column.startswith(prefix):
                avg_column = column
        column_list.append(avg_column)

    dataset_list = []
    for dataset_dict in data_engine.datasets:
        if dataset_dict["name"] == group_name:
            dataset_list = dataset_dict["datasets"]
    if not section:
        column_list.extend(dataset_list)

    # Row-wise means over the group's closed and open datasets.
    group_datasets = set(dataset_list)
    closed_datasets = set(get_closed_dataset())
    close_avg_list = list(group_datasets & closed_datasets)
    df["Closed average"] = df[close_avg_list].mean(axis=1).round(2)
    column_list.append("Closed average")
    open_avg_list = list(group_datasets - closed_datasets)
    df["Open average"] = df[open_avg_list].mean(axis=1).round(2)
    column_list.append("Open average")

    df = df[COLUMNS + column_list].sort_values(by=avg_column, ascending=False)

    # rename avg column name
    if not section:
        new_column = avg_column.replace(group_label, "").strip()
        df.rename(columns={avg_column: new_column}, inplace=True)
        column_list.remove(avg_column)
        avg_column = new_column

    # The overall / open / closed averages have dedicated columns; everything
    # else in ``column_list`` becomes a score column with the custom header.
    summary_columns = (avg_column, "Closed average", "Open average")
    score_columns = [column for column in column_list if column not in summary_columns]

    # setting column config
    grid = AgGrid(
        df,
        enable_enterprise_modules=False,
        gridOptions=_build_grid_options(avg_column, score_columns, grid_state),
        allow_unsafe_jscode=True,
        columns_auto_size_mode=ColumnsAutoSizeMode.FIT_ALL_COLUMNS_TO_VIEW,
        theme="streamlit",
        custom_css=_GRID_CUSTOM_CSS,
        update_on=["stateUpdated"],
    )

    if st.button("Share this page", icon=":material/share:"):
        _render_share_link(grid, group_name, section)