Spaces:
Runtime error
Runtime error
| from langchain.llms.huggingface_pipeline import HuggingFacePipeline | |
| from langchain.retrievers.multi_query import MultiQueryRetriever | |
| # Set logging for the queries | |
| import logging | |
# Give the root logger a default handler. Without one, the INFO records that
# MultiQueryDocumentRetriever.initialize enables on the
# "langchain.retrievers.multi_query" logger would propagate with nowhere to go.
logging.basicConfig()
class MultiQueryDocumentRetriever:
    """Multi-query document retrieval over an existing vector store.

    Wraps LangChain's ``MultiQueryRetriever`` around ``vector_store.db``,
    using a local Hugging Face text-generation pipeline as the LLM.
    Per the LangChain docs, that LLM rewrites the user query into several
    alternative phrasings, and the documents found for each are merged.

    The model and retriever are heavy to build, so they are created by
    ``initialize()`` rather than in the constructor. ``retrieve()`` calls
    ``initialize()`` automatically on first use.
    """

    def __init__(self, vector_store):
        """Store the vector store; defer model and retriever construction.

        Args:
            vector_store: Object whose ``db`` attribute provides
                ``as_retriever(search_kwargs=...)`` (presumably a LangChain
                vector store; only that call is used here).
        """
        self.vector_store = vector_store
        self.retriever = None  # built by initialize()
        self.llm = None  # built by initialize()

    def initialize(self):
        """Load the text-generation LLM and build the multi-query retriever.

        Sets ``self.llm`` to a ``HuggingFacePipeline`` for
        ``bigscience/bloomz-1b7`` and ``self.retriever`` to a
        ``MultiQueryRetriever`` over ``vector_store.db``.
        """
        self.llm = HuggingFacePipeline.from_model_id(
            model_id="bigscience/bloomz-1b7",
            task="text-generation",
            model_kwargs={"do_sample": True, "temperature": 0.7, "num_beams": 2},
            pipeline_kwargs={"max_new_tokens": 256, "min_new_tokens": 30},
        )
        # Surface the generated query variants in the logs.
        logging.getLogger("langchain.retrievers.multi_query").setLevel(logging.INFO)
        self.retriever = MultiQueryRetriever.from_llm(
            retriever=self.vector_store.db.as_retriever(
                search_kwargs={"k": 4, "fetch_k": 40}
            ),
            llm=self.llm,
        )

    def retrieve(self, query: str, k: int = 4) -> list:
        """Return up to ``k`` documents relevant to ``query``.

        Builds the model and retriever on first use if ``initialize()`` has
        not been called yet.

        Args:
            query: The user's natural-language query.
            k: Maximum number of documents to return. A value <= 0 yields an
                empty list.

        Returns:
            The first ``k`` documents produced by the multi-query retriever.
        """
        if self.retriever is None:
            self.initialize()
        documents = self.retriever.get_relevant_documents(query)
        # The multi-query union can exceed k, so cap it. Clamp to 0 so a
        # negative k cannot turn into "drop from the end" slicing.
        return list(documents[: max(k, 0)])