|
|
from langchain.docstore.document import Document |
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
|
from smolagents import Tool |
|
|
from langchain_community.retrievers import BM25Retriever |
|
|
from smolagents import CodeAgent, InferenceClientModel |
|
|
from datasets import load_dataset |
|
|
import re |
|
|
import pandas as pd |
|
|
|
|
|
class QuestionRetrieverTool(Tool):
    """Agent tool that looks up stored question-and-answer text for a query.

    The documents passed to the constructor are indexed with a
    ``BM25Retriever``. ``forward`` runs the query against that index and
    concatenates the ``page_content`` of every hit into one string.
    """

    name = "Question_retriever"
    # The retriever built in __init__ is BM25 (lexical ranking), so the
    # description advertises keyword search rather than semantic search.
    description = "Uses BM25 keyword search to retrieve relevant question given the class, difficulty, and topic inputs by the user."
    inputs = {
        "query": {
            "type": "string",
            # This text documents the *argument*; it must tell the calling
            # model what to put in the query, not what the tool returns.
            "description": "The search query describing the wanted questions, e.g. their class, difficulty, and topic.",
        }
    }
    output_type = "string"

    def __init__(self, docs, *, k: int = 5, **kwargs):
        """Build the BM25 index over ``docs``.

        Args:
            docs: Documents to index; forwarded unchanged to
                ``BM25Retriever.from_documents`` (presumably langchain
                ``Document`` objects -- confirm against the caller).
            k: Value forwarded as the retriever's ``k`` setting, i.e. how
                many documents a query should return. Defaults to 5, the
                value that used to be hard-coded.
            **kwargs: Passed through to ``Tool.__init__``.
        """
        super().__init__(**kwargs)
        self.retriever = BM25Retriever.from_documents(docs, k=k)

    def forward(self, query: str) -> str:
        """Return the retrieved question/answer texts for ``query``.

        Args:
            query: Free-text search query.

        Returns:
            A header line followed by one ``===== Q and A pairs <i> =====``
            section per retrieved document, numbered from 0. Only the header
            is returned when the retriever yields nothing.

        Raises:
            AssertionError: If ``query`` is not a string.
        """
        # Raised explicitly (instead of using an ``assert`` statement) so the
        # check is not stripped when Python runs with -O; the exception type
        # and message are the same as before.
        if not isinstance(query, str):
            raise AssertionError("Your search query must be a string")

        hits = self.retriever.invoke(query)
        sections = (
            f"\n\n===== Q and A pairs {i} =====\n" + doc.page_content
            for i, doc in enumerate(hits)
        )
        return "\nRetrieved example question and answer pairs:\n" + "".join(sections)
|
|
|