import inspect
import json
from typing import Any, Dict, Optional

from keyword_extraction import KeywordExtractor
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.conversational_retrieval.base import (
    ConversationalRetrievalChain,
    _get_chat_history,
)
from langchain.schema import Document

class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
    """Conversational retrieval chain with document-count guards and keyword-based reranking."""

    keyword_extractor: KeywordExtractor = KeywordExtractor()

    def _handle_docs(self, docs):
        """Check that the retrieved document set is usable; return (is_valid, message)."""
        if len(docs) == 0:
            return False, "No documents found. Can you rephrase?"
        elif len(docs) == 1:
            return False, "Only one document found. Can you rephrase?"
        elif len(docs) > 10:
            return False, "Too many documents found. Can you specify your request?"
        return True, ""

    def rerank_documents(self, question: str, docs: list[Document]) -> list[Document]:
        """Rerank documents based on the number of keywords they share with the question.

        Args:
            question (str): Original question
            docs (list[Document]): List of documents

        Returns:
            list[Document]: Documents sorted by their number of matching keywords
        """
        keywords = self.keyword_extractor(question)
        for doc in docs:
            doc.metadata["similar_keyword"] = 0
            doc_keywords = json.loads(doc.page_content)["keywords"]
            if doc_keywords is None:
                continue
            # Normalize the comma-separated keyword string before matching.
            doc_keywords = [k.strip() for k in doc_keywords.lower().split(",")]
            for kw in keywords:
                if kw.lower() in doc_keywords:
                    doc.metadata["similar_keyword"] += 1
                    print("similar keyword:", kw)
        # Ascending sort: documents with the most keyword matches end up last in the context.
        docs = sorted(docs, key=lambda x: x.metadata["similar_keyword"])
        return docs

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs["question"]
        get_chat_history = self.get_chat_history or _get_chat_history
        chat_history_str = get_chat_history(inputs["chat_history"])

        # Condense the follow-up question into a standalone one when there is chat history.
        if chat_history_str:
            callbacks = _run_manager.get_child()
            new_question = self.question_generator.run(
                question=question, chat_history=chat_history_str, callbacks=callbacks
            )
        else:
            new_question = question

        accepts_run_manager = (
            "run_manager" in inspect.signature(self._get_docs).parameters
        )
        if accepts_run_manager:
            docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
        else:
            docs = self._get_docs(new_question, inputs)  # type: ignore[call-arg]

        # Short-circuit with a clarification message if the retrieved set is unusable.
        valid_docs, message = self._handle_docs(docs)
        if not valid_docs:
            return {
                self.output_key: message,
                "source_documents": docs,
            }

        # Rerank the retrieved documents by keyword overlap with the (rephrased) question.
        docs = self.rerank_documents(new_question, docs)

        new_inputs = inputs.copy()
        if self.rephrase_question:
            new_inputs["question"] = new_question
        new_inputs["chat_history"] = chat_history_str
        answer = self.combine_docs_chain.run(
            input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
        )
        output: Dict[str, Any] = {self.output_key: answer}
        if self.return_source_documents:
            output["source_documents"] = docs
        if self.return_generated_question:
            output["generated_question"] = new_question
        return output
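

# A minimal usage sketch, not part of the original chain code: the LLM, embeddings,
# FAISS index path, and retriever settings below are assumptions for illustration only,
# and the indexed documents are assumed to store a JSON payload with a "keywords" field,
# as rerank_documents expects.
if __name__ == "__main__":
    from langchain.chat_models import ChatOpenAI
    from langchain.embeddings import OpenAIEmbeddings
    from langchain.vectorstores import FAISS

    vectorstore = FAISS.load_local("faiss_index", OpenAIEmbeddings())  # assumed local index
    chain = CustomConversationalRetrievalChain.from_llm(
        llm=ChatOpenAI(temperature=0),
        retriever=vectorstore.as_retriever(search_kwargs={"k": 10}),
        return_source_documents=True,
    )
    result = chain({"question": "What does the report say about X?", "chat_history": []})
    print(result["answer"])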