import logging
import os
from typing import List

from langchain_core.documents import Document
from langchain_milvus import Milvus

from sandbox.light_rag.hf_embedding import HFEmbedding
from sandbox.light_rag.hf_llm import HFLLM

context_template = "Document:\n{document}\n"
token_limit = 4096

logger = logging.getLogger()
class LightRAG:
    """Minimal RAG pipeline: Milvus retrieval + Hugging Face generation."""

    def __init__(self, config: dict):
        self.config = config
        # Any non-empty LAZY_LOADING value defers model loading until first use.
        lazy_loading = os.environ.get("LAZY_LOADING")
        self.gen_model = None if lazy_loading else HFLLM(config['generation_model_id'])
        self._embedding_model = None if lazy_loading else HFEmbedding(config['embedding_model_id'])
        # Milvus indices are built lazily per (collection, db) pair and can be
        # warmed up via precache_milvus(); see _get_milvus_index().
        self._pre_cached_indices = {}
    def _get_embedding_model(self):
        if self._embedding_model is None:
            self._embedding_model = HFEmbedding(self.config['embedding_model_id'])
        return self._embedding_model
    def precache_milvus(self, collection, db):
        """Eagerly build and cache the Milvus index for a (collection, db) pair."""
        key = self._cache_key(collection, db)
        self._pre_cached_indices[key] = Milvus(
            embedding_function=self._get_embedding_model(),
            collection_name=collection.replace("-", "_"),  # Milvus forbids '-' in names
            index_params={"metric_type": "COSINE"},
            connection_args={"uri": db},
        )
    def _get_milvus_index(self, collection, db):
        key = self._cache_key(collection, db)
        if key in self._pre_cached_indices:
            logger.debug(f"Milvus index cache hit: {key}")
            return self._pre_cached_indices[key]
        # Cache miss: build a fresh index; only precache_milvus() populates the cache.
        return Milvus(
            embedding_function=self._get_embedding_model(),
            collection_name=collection.replace("-", "_"),
            index_params={"metric_type": "COSINE"},
            connection_args={"uri": db},
        )
    def search(self, query: str, top_n: int = 5, collection=None, db=None) -> List[Document]:
        # TODO: be more clever :)
        col_name = self.config["milvus_collection_name"] if collection is None else collection
        db = self.config["milvus_db_path"] if db is None else db
        vs = self._get_milvus_index(col_name, db)
        # Over-fetch so enough results survive the type filtering below.
        context = vs.similarity_search(
            query=query,
            k=100,
        )
        results = []
        for d in context:
            if d.metadata.get("type") == "text":
                results.append(d)
            elif d.metadata.get("type") == "image_description":
                # Keep at most one image description per source document.
                if not any(r.metadata["document_id"] == d.metadata.get("document_id") for r in results):
                    results.append(d)
        top_n = min(top_n, len(results))
        return results[:top_n]
    def _build_prompt(self, question: str, context: List[Document]):
        # Flatten the retrieved documents into plain text snippets.
        text_documents = []
        for doc in context:
            if doc.metadata['type'] == 'text':
                text_documents.append(doc.page_content.strip())
            elif doc.metadata['type'] == 'image_description':
                text_documents.append(doc.metadata['image_description'].strip())
            else:
                logger.warning(f"Unexpected document type: {doc.metadata.get('type')}")
        documents = [{"text": x} for x in text_documents]
        prompt = self.gen_model.tokenizer.apply_chat_template(
            conversation=[
                {
                    "role": "user",
                    "content": question,
                }
            ],
            documents=documents,  # Uses the documents support in the Granite chat template
            add_generation_prompt=True,
            tokenize=False,
        )
        return prompt
    def generate(self, query, context=None):
        # Lazily load the generation model on first use (see LAZY_LOADING in __init__).
        if self.gen_model is None:
            self.gen_model = HFLLM(self.config["generation_model_id"])
        # Build the prompt, then run inference.
        prompt = self._build_prompt(query, context)
        results = self.gen_model.generate(prompt)
        answer = results[0]["answer"]
        return answer, prompt
    def _cache_key(self, collection, db):
        return f"{collection}___{db}"
# if __name__ == '__main__':
#     from dotenv import load_dotenv
#     load_dotenv()
#
#     config = {
#         "embedding_model_id": "ibm-granite/granite-embedding-125m-english",
#         "generation_model_id": "ibm-granite/granite-3.1-8b-instruct",
#         "milvus_collection_name": "granite_vision_tech_report_text_milvus_lite_512_128_slate_125m_cosine",
#         "milvus_db_path": "/dccstor/mm-rag/adi/code/RAGEval/milvus/text/milvus.db"
#     }
#
#     rag_app = LightRAG(config)
#
#     query = "What models are available in Watsonx?"
#
#     # run retrieval
#     context = rag_app.search(query=query, top_n=5)
#     # generate answers
#     answer, prompt = rag_app.generate(query=query, context=context)
#
#     print(f"Answer:\n{answer}")
#     print(f"Used prompt:\n{prompt}")

# python -m debugpy --connect cccxl009.pok.ibm.com:3002 ./sandbox/light_rag/light_rag.py
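# A minimal warm-up sketch (illustrative only; reuses the config values from the
# example above and assumes the collection already exists in the Milvus db):
#
#     rag_app.precache_milvus(
#         collection=config["milvus_collection_name"],
#         db=config["milvus_db_path"],
#     )
#
# Subsequent search() calls on this (collection, db) pair then reuse the cached
# Milvus index instead of opening a new connection per query.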