Spaces:
Running
Running
"""Script based on Streamlit to visualize data from the Well hosted on the Hugging Face hub.

Any time the state changes (due to UI interaction and callbacks),
the script is evaluated again.
Based on the state attributes, some UI components are rendered
(e.g. a slider for the field time step).
"""
| import pathlib | |
| import fsspec | |
| import h5py | |
| import numpy as np | |
| import pyvista as pv | |
| import streamlit as st | |
| from stpyvista.trame_backend import stpyvista | |
# Datasets whose data will be visualized
DATASET_NAMES = [
    "acoustic_scattering_inclusions",
    "active_matter",
    "helmholtz_staircase",
    "MHD_64",
    "shear_flow",
]
# Suffixes naming the components of vector (t1) fields, one per spatial dimension
DIM_SUFFIXES = ["x", "y", "z"]
# Options for HDF5 cloud optimized reads
IO_PARAMS = {
    "fsspec_params": {
        # "skip_instance_cache": True
        "cache_type": "blockcache",  # or "first" with enough space
        "block_size": 2 * 1024 * 1024,  # could be bigger
        # HF token read from Streamlit secrets — presumably needed for
        # authenticated hub access; TODO confirm datasets are gated
        "token": st.secrets["HF_TOKEN"],
    },
    "h5py_params": {
        "driver_kwds": {  # only recent versions of xarray and h5netcdf allow this correctly
            "page_buf_size": 2 * 1024 * 1024,  # this one only works in repacked files
            "rdcc_nbytes": 2 * 1024 * 1024,  # this one is to read the chunks
        }
    },
}
# Instantiate Streamlit state attributes so later reads never raise KeyError
# (the script re-runs on every UI interaction; only missing keys are seeded)
for key in ["file", "files", "field_names", "spatial_dim", "data"]:
    if key not in st.session_state:
        st.session_state[key] = None
def reset_state(key: str):
    """Remove ``key`` from the Streamlit session state if it is present.

    Deleting the key (rather than setting it to None) also clears any
    widget state bound to it. The previous implementation assigned None
    immediately before ``del``; that assignment was dead code, since the
    deletion removes the entry regardless.
    """
    if key in st.session_state:
        del st.session_state[key]
def get_dataset_path(dataset_name: str) -> str:
    """Build the ``hf://`` URL of a dataset hosted under the polymathic-ai org."""
    organization = "polymathic-ai"
    return f"hf://datasets/{organization}/{dataset_name}"
def get_dataset_files(dataset_name: str):
    """Return every ``.hdf5`` file found (recursively) in the dataset repo."""
    path = get_dataset_path(dataset_name)
    filesystem, _ = fsspec.url_to_fs(path)
    return filesystem.glob(f"{path}/**/*.hdf5")
def get_dataset_info(file_path: str) -> tuple[int, list[str]]:
    """Retrieve the spatial dimension and field names from a dataset file.

    Scalar (t0) fields contribute one entry each; vector (t1) fields
    contribute one entry per spatial dimension, suffixed with x/y/z.
    Each entry is a ``(display_name, tensor_order_group)`` pair.

    Note: the original annotation was ``tuple([int, list[str]])`` — a
    runtime call building a plain tuple instead of a type subscript;
    fixed to ``tuple[int, list[str]]``.
    """
    file_path = f"hf://{file_path}"
    with fsspec.open(file_path, "rb") as f, h5py.File(f, "r") as file:
        spatial_dim = file.attrs["n_spatial_dims"]
        field_names = [(field, "t0_fields") for field in file["t0_fields"].keys()]
        for field in file["t1_fields"].keys():
            # Only the first `spatial_dim` suffixes apply (2D data has no z)
            for _, dim_suffix in zip(range(spatial_dim), DIM_SUFFIXES):
                field_names.append((f"{field}_{dim_suffix}", "t1_fields"))
        return spatial_dim, field_names
def dataset_info_callback():
    """Refresh file list and field metadata after a dataset is selected."""
    selected_name = st.session_state.name
    files = get_dataset_files(selected_name)
    spatial_dim, field_names = get_dataset_info(files[0])
    st.session_state.files = files
    st.session_state.spatial_dim = spatial_dim
    st.session_state.field_names = field_names
    # Field data belonging to the previously selected dataset must be cleared
    reset_state(key="data")
def get_field(file_path: str, field: tuple[str, str], spatial_dim: int) -> np.ndarray:
    """Load the first trajectory of a field in a given file.

    ``field`` is a ``(display_name, tensor_order_group)`` pair. For t1
    (vector) fields the display name carries an x/y/z suffix selecting
    one component. ``spatial_dim`` is accepted for interface stability
    but not read here.
    """
    file_path = f"hf://{file_path}"
    name, tensor_order = field
    dim_index = None
    if tensor_order == "t1_fields":
        # Strip the trailing component suffix and map it to an axis index
        *name_parts, suffix = name.split("_")
        dim_index = DIM_SUFFIXES.index(suffix)
        name = "_".join(name_parts)
    with (
        fsspec.open(file_path, "rb", **IO_PARAMS["fsspec_params"]) as f,
        h5py.File(f, "r", **IO_PARAMS["h5py_params"]) as file,
    ):
        # First trajectory only; for order-1 tensors also pick one component
        selector = 0 if dim_index is None else (0, ..., dim_index)
        return np.array(file[tensor_order][name][selector])
def field_callback():
    """Callback to retrieve field data given file and field name state."""
    file = st.session_state.get("file", None)
    if not file:
        return
    data = get_field(file, st.session_state.field, st.session_state.spatial_dim)
    st.session_state.data = data
    # A field with no time axis is constant: drop any stale slider state
    if st.session_state.data.ndim <= 2:
        reset_state(key="time_step")
def create_plotter() -> pv.Plotter:
    """Create a pyvista.Plotter of the field held in session state.

    Returns:
        A configured ``pv.Plotter`` with the field mapped on a rectilinear grid.

    Raises:
        ValueError: if the spatial dimension is neither 2 nor 3 (previously
            this fell through to a ``NameError`` on the undefined ``grid``).
    """
    # A time-step slider exists only for dynamic fields; its presence means
    # axis 0 of the data is time, shifting the spatial axes by one.
    time_step = st.session_state.get("time_step", None)
    position_offset = 0 if time_step is None else 1
    # Create 2D or 3D grid matching the spatial shape of the data
    spatial_dim = st.session_state.spatial_dim
    spatial_shape = st.session_state.data.shape[position_offset:]
    if spatial_dim == 2:
        nx, ny = spatial_shape
        grid = pv.RectilinearGrid(np.arange(0, nx), np.arange(0, ny))
    elif spatial_dim == 3:
        nx, ny, nz = spatial_shape
        grid = pv.RectilinearGrid(np.arange(0, nx), np.arange(0, ny), np.arange(0, nz))
    else:
        raise ValueError(f"Unsupported spatial dimension: {spatial_dim}")
    # Set the grid scalar field.
    # If no time step is set the field is assumed to be constant in time.
    field_name = st.session_state.field[0]
    if time_step is None:
        grid[field_name] = st.session_state.data.ravel()
    else:
        grid[field_name] = st.session_state.data[time_step].ravel()
    plotter = pv.Plotter(window_size=[400, 400])
    plotter.add_mesh(grid, scalars=field_name)
    if spatial_dim == 2:
        plotter.view_xy()
    elif spatial_dim == 3:
        plotter.view_isometric()
    plotter.background_color = "white"
    return plotter
# Page chrome: title, icon, logo, and introductory text.
st.set_page_config(
    page_title="Tap into the Well", page_icon="assets/the_well_color_icon.svg"
)
st.image("assets/the_well_logo.png")
# Fixed: the documentation link was missing its scheme and rendered as a
# broken relative URL; also repaired grammar in the displayed text.
st.markdown("""
[The Well](https://openreview.net/pdf?id=00Sx577BT3) is a collection of 15TB datasets of physics simulations.
This space allows you to tap into the Well by visualizing different datasets hosted on the [Hugging Face Hub](https://huggingface.co/polymathic-ai).
- Select a dataset
- Select a field
- Select a file
- Visualize different time steps
For fields of higher tensor order (e.g. velocity), loading the data may be slow.
For this reason, we recommend downloading the data to work on the Well.
Check the [documentation](https://the-well.polymathic-ai.org) for more information.
""")
# The order of the following widgets matters: Streamlit re-runs the whole
# script top-to-bottom, and field data is refreshed whenever a file or a
# field is selected, so the selectors must render before the slider/plot.
# Dataset selection — populates files/field metadata via its callback
dataset = st.selectbox(
    "Select a Dataset",
    options=DATASET_NAMES,
    index=None,
    key="name",
    on_change=dataset_info_callback,
)
# Field and file selection, shown only once a dataset has been chosen
if st.session_state.name:
    field_selector = st.selectbox(
        "Select a field",
        key="field",
        options=st.session_state.field_names,
        format_func=lambda option: option[0],  # Fields are (name, tensor_order)
        on_change=field_callback,
    )
    file_selector = st.selectbox(
        "Select a file",
        options=st.session_state.files,
        key="file",
        index=None,
        format_func=lambda option: pathlib.Path(option).name,
        on_change=field_callback,
    )
    if st.session_state.data is not None:
        # Add a time step slider for dynamic fields (time axis present)
        if st.session_state.data.ndim > 2:
            time_step_slider = st.slider(
                "Time step",
                min_value=0,
                value=0,
                max_value=st.session_state.data.shape[0] - 1,
                key="time_step",
            )
        # Render the interactive 3D/2D view of the selected field
        if st.session_state.data is not None:
            plotter = create_plotter()
            stpyvista(plotter)