Spaces:
Runtime error
Runtime error
| import logging | |
| from functools import partial | |
| from typing import Callable, List, Optional | |
| import numpy as np | |
| import pandas as pd | |
| import streamlit as st | |
| import umap | |
| from bokeh.models import ColumnDataSource, HoverTool | |
| from bokeh.palettes import Cividis256 as Pallete | |
| from bokeh.plotting import Figure, figure | |
| from bokeh.transform import factor_cmap | |
| from datasets import load_dataset | |
| from sentence_transformers import SentenceTransformer | |
| from sklearn.manifold import TSNE | |
# Emit INFO-level progress messages from this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Sentence-embedding checkpoints offered in the model selector.
EMBEDDING_MODELS = [
    "distiluse-base-multilingual-cased-v1",
    "all-mpnet-base-v2",
    "flax-sentence-embeddings/all_datasets_v3_mpnet-base",
]

# Projection algorithms offered in the algorithm selector.
DIMENSIONALITY_REDUCTION_ALGORITHMS = ["UMAP", "t-SNE"]

# Single seed reused for dataset shuffling, row sampling and both projections.
SEED = 0
def load_model(model_name: str) -> SentenceTransformer:
    """Instantiate the sentence-embedding model identified by ``model_name``."""
    return SentenceTransformer(model_name)
def embed_text(text: List[str], model: SentenceTransformer) -> np.ndarray:
    """Encode every string in ``text`` with ``model`` and return the result."""
    sentence_vectors = model.encode(text)
    return sentence_vectors
def encode_labels(labels: pd.Series) -> pd.Series:
    """Return ``labels`` in numeric form.

    A series that already has a numeric dtype is handed back unchanged;
    any other series is converted to a categorical and replaced by its
    integer category codes.
    """
    already_numeric = pd.api.types.is_numeric_dtype(labels)
    return labels if already_numeric else labels.astype("category").cat.codes
def get_tsne_embeddings(
    embeddings: np.ndarray, perplexity: int = 30, n_components: int = 2, init: str = "pca", n_iter: int = 5000, random_state: int = SEED
) -> np.ndarray:
    """Project ``embeddings`` to ``n_components`` dimensions with t-SNE.

    Args:
        embeddings: One row per sample.
        perplexity: Requested t-SNE perplexity. It is lowered to
            ``n_samples - 1`` when the input has too few rows for it.
        n_components: Dimensionality of the projection.
        init: Initialisation strategy forwarded to ``TSNE``.
        n_iter: Iteration budget forwarded to ``TSNE``.
        random_state: Seed forwarded to ``TSNE``.

    Returns:
        The array produced by ``TSNE.fit_transform``.
    """
    # scikit-learn requires perplexity < n_samples. The UI lets the user pick
    # very small samples, so shrink the perplexity instead of failing.
    n_samples = len(embeddings)
    if n_samples > 1:
        perplexity = min(perplexity, n_samples - 1)
    tsne = TSNE(perplexity=perplexity, n_components=n_components, init=init, n_iter=n_iter, random_state=random_state)
    return tsne.fit_transform(embeddings)
def get_umap_embeddings(embeddings: np.ndarray) -> np.ndarray:
    """Project ``embeddings`` to two dimensions with UMAP (fixed settings, seeded)."""
    umap_settings = {
        "n_neighbors": 15,
        "min_dist": 0.1,
        "n_components": 2,
        "random_state": SEED,
    }
    reducer = umap.UMAP(**umap_settings)
    projected = reducer.fit_transform(embeddings)
    return projected
def _value_palette_indices(values: np.ndarray) -> tuple:
    """Pair every distinct value with a palette index in the range 0-255.

    Returns a ``(factors, indices)`` tuple of equal-length lists: ``factors``
    holds the distinct values rendered as strings, and ``indices[i]`` is the
    palette position of ``factors[i]`` obtained by min-max scaling the value.
    """
    distinct = np.unique(values)
    min_value = values.min()
    max_value = values.max()
    value_range = max_value - min_value
    if value_range == 0:
        # Every value is identical: use one fixed palette entry.
        indices = np.ones(len(distinct), dtype=int)
    else:
        indices = ((distinct - min_value) / value_range * 255).round().astype(int)
    return distinct.astype(str).tolist(), indices.tolist()


def draw_interactive_scatter_plot(
    texts: np.ndarray, xs: np.ndarray, ys: np.ndarray, values: np.ndarray, labels: np.ndarray, text_column: str, label_column: str
) -> Figure:
    """Build a hoverable scatter plot coloured by ``values``.

    Args:
        texts: Text shown in the hover tooltip of each point.
        xs: Horizontal coordinate of each point.
        ys: Vertical coordinate of each point.
        values: Numeric value of each point; it drives the point colour.
        labels: Original (un-encoded) label of each point, shown on hover.
        text_column: Tooltip caption for the text.
        label_column: Tooltip caption for the label.

    Returns:
        The configured figure.
    """
    values = np.asarray(values)
    if values.dtype == np.bool_:
        # numpy refuses to subtract booleans, so treat them as 0/1.
        values = values.astype(int)

    # One (factor, colour) pair per distinct value, built together so that a
    # value is always paired with its own normalised colour. Sorting the
    # stringified values and the stringified colour indices independently
    # does not keep them aligned ("85" sorts after "170").
    factors, palette_indices = _value_palette_indices(values)
    palette = [Pallete[index] for index in palette_indices]

    values_list = values.astype(str).tolist()
    labels_list = labels.astype(str).tolist()

    source = ColumnDataSource(data=dict(x=xs, y=ys, text=texts, label=values_list, original_label=labels_list))
    hover = HoverTool(tooltips=[(text_column, "@text{safe}"), (label_column, "@original_label")])
    p = figure(plot_width=800, plot_height=800, tools=[hover])
    p.circle("x", "y", size=10, source=source, fill_color=factor_cmap("label", palette=palette, factors=factors))

    # Strip the chart chrome: only the points matter in a 2D projection.
    p.axis.visible = False
    p.xgrid.grid_line_color = None
    p.ygrid.grid_line_color = None
    p.toolbar.logo = None
    return p
def uploaded_file_to_dataframe(uploaded_file: "st.uploaded_file_manager.UploadedFile") -> pd.DataFrame:
    """Parse an uploaded CSV/TSV file into a DataFrame.

    Args:
        uploaded_file: File-like object exposing a ``name`` attribute.

    Returns:
        The parsed table. A file whose name ends in ``.tsv`` (any letter
        case) is read as tab-separated; anything else as comma-separated.
    """
    # NOTE: the annotation is a string so it is never evaluated at definition
    # time; some Streamlit releases do not appear to expose
    # ``st.uploaded_file_manager``, which would otherwise break module import.
    extension = uploaded_file.name.rsplit(".", 1)[-1].lower()
    separator = "\t" if extension == "tsv" else ","
    return pd.read_csv(uploaded_file, sep=separator)
def hub_dataset_to_dataframe(path: str, name: str, split: str, sample: int) -> pd.DataFrame:
    """Load a hub dataset and return up to ``sample`` shuffled rows.

    Args:
        path: Dataset identifier passed to ``load_dataset``.
        name: Dataset configuration; omitted from the call when empty.
        split: Split to load; omitted from the call when empty, in which
            case the first split returned by the loader is used.
        sample: Maximum number of rows to keep after shuffling.

    Returns:
        A DataFrame built from the first ``sample`` rows of the shuffled dataset.
    """
    load_kwargs = {"path": path}
    if name:
        load_kwargs["name"] = name
    if split:
        load_kwargs["split"] = split
    dataset = load_dataset(**load_kwargs)
    if isinstance(dataset, dict):
        # Without an explicit split the loader returns a mapping of
        # split name -> dataset, which cannot be sliced by row. Fall back to
        # its first split.
        dataset = next(iter(dataset.values()))
    rows = dataset.shuffle(seed=SEED)[:sample]
    return pd.DataFrame(rows)
def generate_plot(
    df: pd.DataFrame,
    text_column: str,
    label_column: str,
    sample: Optional[int],
    dimensionality_reduction_function: Callable,
    model: SentenceTransformer,
) -> Figure:
    """Embed the texts of ``df``, project them to 2D and build the scatter plot.

    Args:
        df: Input table. It is not modified.
        text_column: Column holding the texts to embed.
        label_column: Column used to colour the points. When the column is
            missing, every row gets the constant label ``0``.
        sample: Maximum number of rows to plot; falsy means "use all rows".
        dimensionality_reduction_function: Callable mapping the embedding
            matrix to an array whose first two columns are the coordinates.
        model: Sentence-embedding model used to encode the texts.

    Returns:
        The figure produced by ``draw_interactive_scatter_plot``.

    Raises:
        ValueError: If ``text_column`` is not a column of ``df``.
    """
    if text_column not in df.columns:
        raise ValueError(f"The specified column name doesn't exist. Columns available: {df.columns.values}")
    if label_column not in df.columns:
        # ``assign`` returns a new frame, so the caller's DataFrame is left
        # untouched instead of silently gaining a column.
        df = df.assign(**{label_column: 0})
    df = df.dropna(subset=[text_column, label_column])
    if sample:
        df = df.sample(min(sample, df.shape[0]), random_state=SEED)
    with st.spinner(text="Embedding text..."):
        embeddings = embed_text(df[text_column].values.tolist(), model)
    logger.info("Encoding labels")
    encoded_labels = encode_labels(df[label_column])
    with st.spinner("Reducing dimensionality..."):
        embeddings_2d = dimensionality_reduction_function(embeddings)
    logger.info("Generating figure")
    plot = draw_interactive_scatter_plot(
        df[text_column].values, embeddings_2d[:, 0], embeddings_2d[:, 1], encoded_labels.values, df[label_column].values, text_column, label_column
    )
    return plot
# ---------------------------------------------------------------------------
# Page header
# ---------------------------------------------------------------------------
st.title("Embedding Lenses")
st.write("Visualize text embeddings in 2D using colors for continuous or categorical labels.")

# ---------------------------------------------------------------------------
# Data source: an uploaded file takes precedence over a hub dataset
# ---------------------------------------------------------------------------
uploaded_file = st.file_uploader("Choose an csv/tsv file...", type=["csv", "tsv"])
st.write("Alternatively, select a dataset from the [hub](https://huggingface.co/datasets)")
col1, col2, col3 = st.columns(3)
hub_dataset = col1.text_input("Dataset name", "ag_news")
hub_dataset_config = col2.text_input("Dataset configuration", "")
hub_dataset_split = col3.text_input("Dataset split", "train")

# ---------------------------------------------------------------------------
# Plot options
# ---------------------------------------------------------------------------
text_column = st.text_input("Text column name", "text")
label_column = st.text_input("Numerical/categorical column name (ignore if not applicable)", "label")
sample = st.number_input("Maximum number of documents to use", min_value=1, max_value=100000, value=1000)
dimensionality_reduction = st.selectbox("Dimensionality Reduction algorithm", DIMENSIONALITY_REDUCTION_ALGORITHMS, index=0)
model_name = st.selectbox("Sentence embedding model", EMBEDDING_MODELS, index=0)

with st.spinner(text="Loading model..."):
    model = load_model(model_name)

# Anything other than "UMAP" falls back to t-SNE.
if dimensionality_reduction == "UMAP":
    dimensionality_reduction_function = get_umap_embeddings
else:
    dimensionality_reduction_function = get_tsne_embeddings

# ---------------------------------------------------------------------------
# Load the data, build the plot and render it
# ---------------------------------------------------------------------------
if uploaded_file or hub_dataset:
    with st.spinner("Loading dataset..."):
        df = (
            uploaded_file_to_dataframe(uploaded_file)
            if uploaded_file
            else hub_dataset_to_dataframe(hub_dataset, hub_dataset_config, hub_dataset_split, sample)
        )
    plot = generate_plot(df, text_column, label_column, sample, dimensionality_reduction_function, model)
    logger.info("Displaying plot")
    st.bokeh_chart(plot)
    logger.info("Done")