Spaces:
Runtime error
Runtime error
| """ dspy_utils.py | |
| Utilities for building a DSPy-based retrieval-augmented generation (RAG) model. | |
| :author: Didier Guillevic | |
| :email: [email protected] | |
| :creation: 2024-12-21 | |
| """ | |
| import os | |
| import dspy | |
| from ragatouille import RAGPretrainedModel | |
| import logging | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# NOTE: this call runs at import time and requests INFO-level logging
# configuration for the whole process, not just for this module.
logging.basicConfig(level=logging.INFO)
class DSPyRagModel:
    """Question answering over retrieved passages using DSPy predictors.

    The instance keeps a retrieval model (used to fetch passages) and a
    Mistral language model, and exposes two DSPy predictors built from the
    same question-answering signature: a plain one and a chain-of-thought one.
    """

    def __init__(self, retrieval_model: RAGPretrainedModel):
        """Set up the language model, the DSPy settings and the predictors.

        Args:
            retrieval_model: model whose ``search`` method is used to fetch
                passages in ``generate_response``.

        Raises:
            KeyError: if the ``MISTRAL_API_KEY`` environment variable is
                not set.
        """
        # Init the retrieval and language model
        self.retrieval_model = retrieval_model
        self.language_model = dspy.LM(
            model="mistral/mistral-large-latest",
            api_key=os.environ["MISTRAL_API_KEY"],
        )

        # Set dspy retrieval and language model (process-wide DSPy settings)
        dspy.settings.configure(
            lm=self.language_model,
            rm=self.retrieval_model,
        )

        # Set dspy prediction functions.
        # NOTE: the signature docstring and the field descriptions are
        # consumed by DSPy when building the prompt, so they are runtime
        # text and must not be reworded casually.
        class BasicQA(dspy.Signature):
            """Answer the question given the context provided"""
            context = dspy.InputField(desc="may contain relevant facts")
            question = dspy.InputField()
            answer = dspy.OutputField(desc="Answer the given question.")

        self.predict = dspy.Predict(BasicQA, temperature=0.01)
        self.predict_chain_of_thought = dspy.ChainOfThought(BasicQA)

    def generate_response(
        self,
        question: str,
        k: int = 3,
        method: str = 'chain_of_thought'
    ) -> tuple[str, str, str]:
        """Generate a response to a given question using the specified method.

        Args:
            question: the question to answer
            k: number of passages to retrieve
            method: method for generating the response:
                ['simple', 'chain_of_thought']

        Returns:
            - the generated answer
            - (html string): the references (origin of the snippets of text
              used to generate the answer)
            - (html string): the snippets of text used to generate the answer

        Raises:
            ValueError: if ``method`` is not one of the supported values.
        """
        # Pick the predictor first so that an invalid `method` fails fast,
        # before any retrieval work is done.
        if method == 'simple':
            predictor = self.predict
        elif method == 'chain_of_thought':
            predictor = self.predict_chain_of_thought
        else:
            raise ValueError(f"Unknown method: {method}. Expected ['simple', 'chain_of_thought']")

        # Retrieval
        retrieval_results = self.retrieval_model.search(query=question, k=k)
        passages = [res.get('content') for res in retrieval_results]
        # `document_metadata` may be absent from a result, in which case the
        # corresponding entry is None.
        metadatas = [res.get('document_metadata') for res in retrieval_results]

        # Generate response given retrieved passages
        response = predictor(context=passages, question=question).answer

        # Create HTML strings with the references and the snippets
        references = "<h4>References</h4>\n" + create_bulleted_list(metadatas)
        snippets = "<h4>Snippets</h4>\n" + create_bulleted_list(passages)

        return response, references, snippets
def create_bulleted_list(texts: list[str]) -> str:
    """Return an HTML unordered list with one ``<li>`` per item of *texts*.

    Each item is converted with ``str()`` (so non-string items such as
    metadata dictionaries are accepted) and HTML-escaped, so that ``<``,
    ``>`` and ``&`` coming from document text are displayed literally
    instead of being interpreted as markup.

    Args:
        texts: the items to display, in order.

    Returns:
        The ``<ul>...</ul>`` markup; ``"<ul></ul>"`` when *texts* is empty.
    """
    # Local import keeps this function self-contained.
    from html import escape

    # quote=False: quotes are harmless inside element content, so only
    # &, < and > are rewritten.
    items = "".join(
        f"<li>{escape(str(item), quote=False)}</li>" for item in texts
    )
    return "<ul>" + items + "</ul>"