Spaces:
Sleeping
Sleeping
use num_proc for loading
Browse files
app.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
from functools import lru_cache
|
| 2 |
|
| 3 |
import duckdb
|
|
@@ -14,7 +15,7 @@ model_name = "sentence-transformers/static-retrieval-mrl-en-v1"
|
|
| 14 |
model = SentenceTransformer(
|
| 15 |
model_name,
|
| 16 |
device="cpu",
|
| 17 |
-
tokenizer_kwargs={"model_max_length": 512},
|
| 18 |
)
|
| 19 |
|
| 20 |
|
|
@@ -35,7 +36,7 @@ def get_iframe(hub_repo_id):
|
|
| 35 |
|
| 36 |
def load_dataset_from_hub(hub_repo_id: str):
|
| 37 |
gr.Info(message="Loading dataset...")
|
| 38 |
-
ds = load_dataset(hub_repo_id)
|
| 39 |
|
| 40 |
|
| 41 |
def get_columns(hub_repo_id: str, split: str):
|
|
@@ -50,7 +51,7 @@ def get_columns(hub_repo_id: str, split: str):
|
|
| 50 |
|
| 51 |
|
| 52 |
def get_splits(hub_repo_id: str):
|
| 53 |
-
ds = load_dataset(hub_repo_id)
|
| 54 |
splits = list(ds.keys())
|
| 55 |
return gr.Dropdown(
|
| 56 |
choices=splits, value=splits[0], label="Select a split", visible=True
|
|
@@ -60,7 +61,7 @@ def get_splits(hub_repo_id: str):
|
|
| 60 |
@lru_cache
|
| 61 |
def vectorize_dataset(hub_repo_id: str, split: str, column: str):
|
| 62 |
gr.Info("Vectorizing dataset...")
|
| 63 |
-
ds = load_dataset(hub_repo_id)
|
| 64 |
df = ds[split].to_polars()
|
| 65 |
embeddings = model.encode(df[column].cast(str).to_list(), show_progress_bar=True, batch_size=128)
|
| 66 |
return embeddings
|
|
@@ -68,7 +69,7 @@ def vectorize_dataset(hub_repo_id: str, split: str, column: str):
|
|
| 68 |
|
| 69 |
def run_query(hub_repo_id: str, query: str, split: str, column: str):
|
| 70 |
embeddings = vectorize_dataset(hub_repo_id, split, column)
|
| 71 |
-
ds = load_dataset(hub_repo_id)
|
| 72 |
df = ds[split].to_polars()
|
| 73 |
df = df.with_columns(pl.Series(embeddings).alias("embeddings"))
|
| 74 |
try:
|
|
|
|
| 1 |
+
import os
|
| 2 |
from functools import lru_cache
|
| 3 |
|
| 4 |
import duckdb
|
|
|
|
| 15 |
model = SentenceTransformer(
|
| 16 |
model_name,
|
| 17 |
device="cpu",
|
| 18 |
+
tokenizer_kwargs={"model_max_length": 512}, # arbitrary for this model, here to keep things fast
|
| 19 |
)
|
| 20 |
|
| 21 |
|
|
|
|
| 36 |
|
| 37 |
def load_dataset_from_hub(hub_repo_id: str):
|
| 38 |
gr.Info(message="Loading dataset...")
|
| 39 |
+
ds = load_dataset(hub_repo_id, num_proc=os.cpu_count())
|
| 40 |
|
| 41 |
|
| 42 |
def get_columns(hub_repo_id: str, split: str):
|
|
|
|
| 51 |
|
| 52 |
|
| 53 |
def get_splits(hub_repo_id: str):
|
| 54 |
+
ds = load_dataset(hub_repo_id, num_proc=os.cpu_count())
|
| 55 |
splits = list(ds.keys())
|
| 56 |
return gr.Dropdown(
|
| 57 |
choices=splits, value=splits[0], label="Select a split", visible=True
|
|
|
|
| 61 |
@lru_cache
|
| 62 |
def vectorize_dataset(hub_repo_id: str, split: str, column: str):
|
| 63 |
gr.Info("Vectorizing dataset...")
|
| 64 |
+
ds = load_dataset(hub_repo_id, num_proc=os.cpu_count())
|
| 65 |
df = ds[split].to_polars()
|
| 66 |
embeddings = model.encode(df[column].cast(str).to_list(), show_progress_bar=True, batch_size=128)
|
| 67 |
return embeddings
|
|
|
|
| 69 |
|
| 70 |
def run_query(hub_repo_id: str, query: str, split: str, column: str):
|
| 71 |
embeddings = vectorize_dataset(hub_repo_id, split, column)
|
| 72 |
+
ds = load_dataset(hub_repo_id, num_proc=os.cpu_count())
|
| 73 |
df = ds[split].to_polars()
|
| 74 |
df = df.with_columns(pl.Series(embeddings).alias("embeddings"))
|
| 75 |
try:
|