Spaces:
Running
on
Zero
Running
on
Zero
| from typing import List | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from sandbox.light_rag.utils import get_device | |
class HFEmbedding:
    """Embed text with a Hugging Face model via LangChain's ``HuggingFaceEmbeddings``.

    Every embedding call (batch, documents, single query) is forwarded to
    ``self.embeddings_service.embed_documents``, so documents and queries are
    encoded the same way.
    """

    def __init__(
        self,
        model_id: str,
        *,
        device: "str | None" = None,
    ):
        """Load the embedding model.

        Args:
            model_id: Model name or path, passed to ``HuggingFaceEmbeddings``
                as ``model_name`` and stored on ``self.model_name``.
            device: Device string forwarded as ``model_kwargs={"device": ...}``.
                When omitted, CPU is forced (ZeroGPU workaround below), which is
                the same behaviour as before this parameter existed.
        """
        if device is None:
            # TODO: hack for zeroGPU -- the auto-detected device is discarded
            # and CPU is forced. get_device() is still called so that any side
            # effects it may have stay unchanged; remove the call together with
            # the hack.
            get_device()
            device = "cpu"
        print(f"Using device: {device}")
        if device == "cpu":
            print("Using CPU might be too slow")
        self.model_name = model_id
        print(f"Loading embeddings model from: {self.model_name}")
        self.embeddings_service = HuggingFaceEmbeddings(
            model_name=self.model_name,
            model_kwargs={"device": device},
        )

    def embed_batch(self, batch: list[str]) -> list[list[float]]:
        """Embed a batch of texts; identical to :meth:`embed_documents`."""
        return self.embed_documents(batch)

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Return the service's embeddings for ``texts``.

        Presumably one vector per input text (LangChain ``Embeddings``
        contract) -- not checked here.
        """
        return self.embeddings_service.embed_documents(texts)

    def embed_query(self, text: str) -> list[float]:
        """Embed a single string by embedding it as a one-element document batch.

        Deliberately routed through :meth:`embed_documents` rather than the
        service's own ``embed_query``, so queries and documents share one path.
        """
        return self.embed_documents([text])[0]