Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import base64 | |
| import gzip | |
| import json | |
| from dataclasses import dataclass, fields | |
| from io import BytesIO | |
| from pathlib import Path | |
| from urllib.parse import parse_qsl | |
| import altair as alt | |
| import ipywidgets as widgets | |
| import numpy as np | |
| import pandas as pd | |
| import solara | |
| import solara.lab | |
| from cmap import Colormap | |
| from ipymolstar.widget import PDBeMolstar | |
| from pydantic import BaseModel | |
| from make_link import decode_data | |
# Fractional padding added to each end of the Y axis when autoscaling is off.
PAD_SIZE = 0.05

# Element-wise version of ``np.base_repr``: turns arrays of integers into
# their string representation in an arbitrary base (e.g. 16 for hex).
base_v = np.vectorize(np.base_repr)
def norm(x, vmin, vmax):
    """Linearly rescale ``x`` so that ``vmin`` maps to 0 and ``vmax`` maps to 1.

    Values outside ``[vmin, vmax]`` are not clipped; they simply land outside
    the unit interval.
    """
    span = vmax - vmin
    offset = x - vmin
    return offset / span
class ColorTransform(BaseModel):
    """Colour-mapping settings shared by the Mol* viewer and the Altair chart.

    Attributes are plain pydantic fields so the model can be built directly
    from a settings dict or from URL query parameters.
    """

    # Colormap identifier handed to ``cmap.Colormap``.
    name: str = "tol:rainbow_PuRd"
    # "categorical" uses values as-is; anything else is normalised linearly.
    norm_type: str = "linear"
    # Data values mapped to the start / end of the colormap (linear mode).
    vmin: float = 0.0
    vmax: float = 1.0
    # Colour for residues without data (also used as the colormap "bad" colour).
    missing_data_color: str = "#8c8c8c"
    # Colour of the highlight rule in the chart and of highlighted residues.
    highlight_color: str = "#e933f8"

    def molstar_colors(self, data: pd.DataFrame) -> dict:
        """Build the per-residue colour payload for the Mol* viewer.

        Args:
            data: Frame with ``residue_number`` and ``value`` columns.

        Returns:
            A dict with a ``data`` list of
            ``{"residue_number": ..., "color": "#rrggbb"}`` entries and a
            ``nonSelectedColor`` entry set to ``missing_data_color``.
        """
        if self.norm_type == "categorical":
            values = data["value"]
        else:
            values = norm(data["value"], vmin=self.vmin, vmax=self.vmax)

        # One RGBA row per residue. Formatting the channels directly avoids
        # reinterpreting the bytes as uint32 (which depends on platform byte
        # order) and keeps working when ``data`` has a single row.
        rgba_array = np.asarray(self.cmap(values, bytes=True)).reshape(-1, 4)
        hex_colors = [
            f"#{red:02x}{green:02x}{blue:02x}"
            for red, green, blue, _alpha in rgba_array.tolist()
        ]

        return {
            "data": [
                {"residue_number": resi, "color": hcolor}
                for resi, hcolor in zip(data["residue_number"], hex_colors)
            ],
            "nonSelectedColor": self.missing_data_color,
        }

    @property
    def cmap(self) -> Colormap:
        """Colormap named by ``name`` with ``missing_data_color`` as bad colour.

        A new ``Colormap`` is constructed on every access.
        """
        return Colormap(self.name, bad=self.missing_data_color)

    @property
    def altair_scale(self) -> alt.Scale:
        """Altair colour scale equivalent to this transform (clamped)."""
        cmap = self.cmap  # build once; each property access creates a new one
        if self.norm_type == "categorical":
            colors = cmap.to_altair(N=cmap.num_colors)
            domain = range(cmap.num_colors)
        else:
            colors = cmap.to_altair()
            # NOTE(review): 256 stops assume ``to_altair()`` returns 256
            # colours by default -- confirm against the cmap documentation.
            domain = np.linspace(self.vmin, self.vmax, 256, endpoint=True)
        return alt.Scale(domain=list(domain), range=colors, clamp=True)
class AxisProperties(BaseModel):
    """Labelling and scaling options for the value axis of the chart."""

    # Name of the plotted quantity.
    label: str = "x"
    # Unit of the plotted quantity.
    unit: str = "au"
    # When True the Y axis range is left to Altair instead of being fixed.
    autoscale_y: bool = True

    @property
    def title(self) -> str:
        """Axis / legend title in the form ``"<label> (<unit>)"``.

        Exposed as a property because callers read it as an attribute
        (``axis_properties.title``).
        """
        return f"{self.label} ({self.unit})"
def make_chart(
    data: pd.DataFrame, colors: ColorTransform, axis_properties: AxisProperties
) -> alt.LayerChart:
    """Build the layered residue-vs-value chart.

    The chart has four layers: the coloured scatter points, a vertical rule
    driven by the ``line_position`` / ``line_opacity`` parameters, an
    invisible point layer carrying the nearest-point selection named
    ``point``, and a rule drawn at the currently selected residue.

    Args:
        data: Frame with ``residue_number`` and ``value`` columns.
        colors: Colour settings; supplies the colour scale, the highlight
            colour and (when not autoscaling) the Y range.
        axis_properties: Axis title and autoscale flag.

    Returns:
        The layered chart, sized to its container width and 480 px high.
    """
    xmin, xmax = data["residue_number"].min(), data["residue_number"].max()
    xpad = (xmax - xmin) * 0.05
    xscale = alt.Scale(domain=(xmin - xpad, xmax + xpad))

    if axis_properties.autoscale_y:
        y_scale = alt.Scale()
    elif colors.norm_type == "categorical":
        # Categories are plotted at 0 .. num_colors - 1.
        num_colors = colors.cmap.num_colors
        ypad = num_colors * PAD_SIZE
        y_scale = alt.Scale(domain=(0 - ypad, num_colors - 1 + ypad))
    else:
        ypad = (colors.vmax - colors.vmin) * PAD_SIZE
        y_scale = alt.Scale(domain=(colors.vmin - ypad, colors.vmax + ypad))

    # Wheel zoom bound to the x scale only; the event filter skips wheel
    # events fired while Shift is held.
    zoom_x = alt.selection_interval(
        bind="scales",
        encodings=["x"],
        zoom="wheel![!event.shiftKey]",
    )

    color_type = "O" if colors.norm_type == "categorical" else "Q"
    scatter = (
        alt.Chart(data)
        .mark_circle(interpolate="basis", size=200)
        .encode(
            x=alt.X("residue_number:Q", title="Residue Number", scale=xscale),
            y=alt.Y(
                "value:Q",
                title=axis_properties.title,
                scale=y_scale,
            ),
            color=alt.Color(
                f"value:{color_type}",
                scale=colors.altair_scale,
                title=axis_properties.title,
            ),
        )
        .add_params(zoom_x)
    )

    # Create a selection that chooses the nearest point & selects based on x-value
    nearest = alt.selection_point(
        name="point",
        nearest=True,
        on="pointerover",
        fields=["residue_number"],
        empty=False,
        clear="mouseout",
    )

    # Transparent points whose only job is to carry the selection.
    select_residue = (
        alt.Chart(data)
        .mark_point()
        .encode(
            x="residue_number:Q",
            opacity=alt.value(0),
        )
        .add_params(nearest)
    )

    # Draw a rule at the location of the selection
    rule = (
        alt.Chart(data)
        .mark_rule(color=colors.highlight_color, size=2)
        .encode(
            x="residue_number:Q",
        )
        .transform_filter(nearest)
    )

    # Externally controlled vertical rule: the single data row has x == 1, so
    # the calculated field ``p`` equals the ``line_position`` parameter.
    line_position = alt.param(name="line_position", value=0.0)
    line_opacity = alt.param(name="line_opacity", value=1)
    df_line = pd.DataFrame({"x": [1.0]})
    vline = (
        alt.Chart(df_line)
        .mark_rule(color=colors.highlight_color, opacity=line_opacity, size=2)
        .encode(x=alt.X("p", type="quantitative"))
        .transform_calculate(p=alt.datum.x * line_position)
        .add_params(line_position, line_opacity)
    )

    # Put the four layers into a chart and bind the data
    chart = alt.layer(scatter, vline, select_residue, rule).properties(
        width="container",
        height=480,  # autosize height?
    )
    return chart
@solara.component
def ScatterChart(
    data: pd.DataFrame,
    colors: ColorTransform,
    axis_properties: AxisProperties,
    on_selections,
    line_value,
):
    """Render the Altair chart as a ``JupyterChart`` widget.

    Args:
        data: Frame with ``residue_number`` and ``value`` columns.
        colors: Colour settings for the chart.
        axis_properties: Axis title and autoscale flag.
        on_selections: Traitlets observer called when the ``point`` selection
            of the chart changes.
        line_value: X position of the externally driven vertical rule, or
            ``None`` to hide it.
    """
    chart_dependencies = [data, colors, axis_properties]

    def mem_chart():
        return make_chart(data, colors, axis_properties)

    # Rebuild the chart only when one of its inputs changes.
    chart = solara.use_memo(mem_chart, dependencies=chart_dependencies)

    if line_value is not None:
        params = {"line_position": line_value, "line_opacity": 1}
    else:
        # Keep the rule in the chart but make it fully transparent.
        params = {"line_position": 0.0, "line_opacity": 0}

    dark_effective = solara.lab.use_dark_effective()
    if dark_effective:
        options = {"actions": False, "theme": "dark"}
    else:
        options = {"actions": False}

    view = alt.JupyterChart.element(  # type: ignore
        chart=chart,
        embed_options=options,
        _params=params,
    )

    def bind():
        real = solara.get_widget(view)
        selections = real.selections  # type: ignore
        selections.observe(on_selections, "point")

        def unbind():
            # Remove the observer so handlers do not accumulate across
            # re-binds or outlive the component.
            selections.unobserve(on_selections, "point")

        return unbind

    # Re-bind whenever the chart is rebuilt (same dependencies as the memo).
    # NOTE(review): JupyterChart appears to recreate ``selections`` when its
    # ``chart`` trait changes -- confirm against the Altair documentation.
    solara.use_effect(bind, chart_dependencies)
@solara.component
def ProteinView(
    title: str,
    molecule_id: str,
    data: pd.DataFrame,
    colors: ColorTransform,
    axis_properties: AxisProperties,
    dark_effective: bool,
    description: str = "",
):
    """Show a coloured Mol* structure next to the linked scatter chart.

    Hovering a residue in the structure moves the vertical rule in the chart;
    hovering a point in the chart highlights the residue in the structure.

    Args:
        title: Text shown in the app bar.
        molecule_id: Structure identifier; lower-cased before it is passed to
            the viewer.
        data: Frame with ``residue_number`` and ``value`` columns; may be
            empty, in which case no colours are applied and "No data" is shown
            instead of the chart.
        colors: Colour settings shared by the structure and the chart.
        axis_properties: Label / unit used for tooltips and chart axes.
        dark_effective: Whether the dark theme is active.
        description: Optional markdown shown in the "About" dialog; the
            button that opens it is only added when this is non-empty.
    """
    about_dialog = solara.use_reactive(False)
    fullscreen = solara.use_reactive(False)
    # residue number to highlight in altair chart
    line_number = solara.use_reactive(None)
    # residue number to highlight in protein view
    highlight_number = solara.use_reactive(None)

    color_data = {} if data.empty else colors.molstar_colors(data)

    tooltips = {
        "data": [
            {
                "residue_number": resi,
                "tooltip": f"{axis_properties.label}: {value:.2g} {axis_properties.unit}"
                if not np.isnan(value)
                else "No data",
            }
            for resi, value in zip(data["residue_number"], data["value"])
        ]
    }

    def on_molstar_mouseover(value):
        # The event payload may lack "residueNumber"; None hides the rule.
        line_number.set(value.get("residueNumber", None))

    def on_molstar_mouseout(value):
        on_molstar_mouseover({})

    def on_chart_selection(event):
        try:
            r = event["new"].value[0]["residue_number"]
        except (IndexError, KeyError):
            # Empty selection: clear the highlight.
            highlight_number.set(None)
        else:
            highlight_number.set(r)

    # Compare against None rather than truthiness so that residue number 0
    # can still be highlighted.
    highlighted = highlight_number.value
    if highlighted is not None:
        highlight = {"data": [{"residue_number": int(highlighted)}]}
    else:
        highlight = None

    with solara.AppBar():
        solara.AppBarTitle(title)
        with solara.Tooltip("Fullscreen"):
            solara.Button(
                icon_name="mdi-fullscreen",
                icon=True,
                on_click=lambda: fullscreen.set(not fullscreen.value),
            )
        if description:
            with solara.Tooltip("About"):
                solara.Button(
                    icon_name="mdi-information-outline",
                    icon=True,
                    on_click=lambda: about_dialog.set(True),
                )
        solara.lab.ThemeToggle()

    with solara.v.Dialog(
        v_model=about_dialog.value, on_v_model=lambda _ignore: about_dialog.set(False)
    ):
        with solara.Card("About", margin=0):
            solara.Markdown(description)

    with solara.ColumnsResponsive([4, 8]):
        with solara.Card(style={"height": "550px"}):
            # The key includes the theme, so the viewer element is keyed
            # differently for light and dark mode.
            PDBeMolstar.element(  # type: ignore
                theme="dark" if dark_effective else "light",
                molecule_id=molecule_id.lower(),
                color_data=color_data,
                hide_water=True,
                tooltips=tooltips,
                height="525px",
                highlight=highlight,
                highlight_color=colors.highlight_color,
                on_mouseover_event=on_molstar_mouseover,
                on_mouseout_event=on_molstar_mouseout,
                hide_controls_icon=True,
                hide_expand_icon=True,
                hide_settings_icon=True,
                expanded=fullscreen.value,
            ).key(f"molstar-{dark_effective}")
        if not fullscreen.value:
            with solara.Card(style={"height": "550px"}):
                if data.empty:
                    solara.Text("No data")
                else:
                    ScatterChart(
                        data,
                        colors,
                        axis_properties,
                        on_chart_selection,
                        line_number.value,
                    )
@solara.component
def RoutedView():
    """Render a ``ProteinView`` configured entirely from the URL query string.

    The query must contain ``title``, ``molecule_id`` and ``data`` (decoded
    with ``decode_data``); ``description`` is optional. The same query
    parameters are also used to build the ``ColorTransform`` and
    ``AxisProperties`` models. A warning is rendered instead of the view when
    a required parameter is missing or a value cannot be parsed.
    """
    route = solara.use_router()
    dark_effective = solara.lab.use_dark_effective()

    try:
        # ``search`` may be empty when the page is opened without a query.
        query_dict = dict(parse_qsl(route.search or ""))
        colors = ColorTransform(**query_dict)  # type: ignore
        axis_properties = AxisProperties(**query_dict)  # type: ignore
        data = decode_data(query_dict["data"])
        title = query_dict["title"]
        molecule_id = query_dict["molecule_id"]
    except (KeyError, ValueError) as err:
        # KeyError: required parameter missing. ValueError: malformed value
        # (pydantic validation errors derive from ValueError).
        # NOTE(review): ``decode_data`` is defined elsewhere; confirm which
        # exceptions it raises on corrupt input.
        solara.Warning(f"Error: {err}")
        return

    ProteinView(
        title,
        molecule_id=molecule_id,
        data=data,
        colors=colors,
        axis_properties=axis_properties,
        dark_effective=dark_effective,
        description=query_dict.get("description", ""),
    )
@solara.component
def Page():
    """Top-level page: show the protein view configured by local files.

    Settings are read from ``settings.json`` and the per-residue values from
    ``color_data.csv``, both relative to the working directory.
    """
    dark_effective = solara.lab.use_dark_effective()
    dark_previous = solara.use_previous(dark_effective)

    # Switch the Altair theme only when the effective theme differs from the
    # value recorded for the previous render.
    if dark_previous != dark_effective:
        if dark_effective:
            alt.themes.enable("dark")
        else:
            alt.themes.enable("default")

    solara.Style(
        """
        .vega-embed {
        overflow: visible;
        width: 100% !important;
        }"""
    )

    def load_inputs():
        settings = json.loads(Path("settings.json").read_text(encoding="utf-8"))
        data = pd.read_csv("color_data.csv")
        return settings, data

    # Load once per component instance: re-reading on every render would hand
    # a new DataFrame object to the chart each time and defeat its memoisation.
    settings, data = solara.use_memo(load_inputs, dependencies=[])

    colors = ColorTransform(**settings)
    axis_properties = AxisProperties(**settings)

    ProteinView(
        settings["title"],
        molecule_id=settings["molecule_id"],
        data=data,
        colors=colors,
        axis_properties=axis_properties,
        dark_effective=dark_effective,
    )