|
|
import re |
|
|
import contextlib |
|
|
from langchain_core.messages import HumanMessage |
|
|
from langchain_openai import ChatOpenAI |
|
|
|
|
|
class ToolRetriever:
    """Select the resources relevant to a query by prompting an LLM."""

    def __init__(self):
        pass

    def prompt_based_retrieval(self, query: str, resources: dict, llm=None) -> dict:
        """Use a prompt-based approach to retrieve the most relevant resources for a query.

        The available tools and libraries are numbered and shown to the LLM,
        which is asked to answer with the indices of the relevant entries.
        Those indices are then mapped back onto the original resource objects.

        Args:
            query: The user's query.
            resources: A dictionary whose 'tools' and 'libraries' keys each hold
                a list of available resources. Any other key (for example
                'data_lake') is ignored and does not appear in the result.
            llm: Optional LLM to use. An object exposing ``invoke`` is called
                with a single ``HumanMessage``; anything else is called
                directly with the prompt string. If None, a ``ChatOpenAI``
                ("gpt-4o") instance is created.

        Returns:
            A dictionary with the keys 'tools' and 'libraries', each containing
            only the selected resources in the order the LLM listed them.
            Out-of-range, negative and repeated indices are dropped.
        """
        prompt = self._build_prompt(query, resources)

        if llm is None:
            llm = ChatOpenAI(model="gpt-4o")

        if hasattr(llm, "invoke"):
            # Chat-model style interface: send one human message, read `.content`.
            response = llm.invoke([HumanMessage(content=prompt)])
            response_content = response.content
        else:
            # Plain callable: pass the raw prompt and stringify whatever comes back.
            response_content = str(llm(prompt))

        selected_indices = self._parse_llm_response(response_content)

        return {
            category: self._pick_by_index(
                resources.get(category) or [],
                selected_indices.get(category, []),
            )
            for category in ("tools", "libraries")
        }

    def _build_prompt(self, query: str, resources: dict) -> str:
        """Render the selection prompt listing every available library and tool."""
        libraries_text = self._format_resources_for_prompt(resources.get("libraries") or [])
        tools_text = self._format_resources_for_prompt(resources.get("tools") or [])

        # The response format must name every category that _parse_llm_response
        # looks for; otherwise that category can never be selected.
        return f"""
You are an expert histopathology research assistant. Your task is to select the relevant resources to help answer a user's query.

USER QUERY: {query}

Below are the available resources. For each category, select items that are directly or indirectly relevant to answering the query.
Be generous in your selection - include resources that might be useful for the task, even if they're not explicitly mentioned in the query.
It's better to include slightly more resources than to miss potentially useful ones.

AVAILABLE SOFTWARE LIBRARIES:
{libraries_text}

AVAILABLE TOOLS:
{tools_text}

For each category, respond with ONLY the indices of the relevant items in the following format:
TOOLS: [list of indices]
LIBRARIES: [list of indices]

For example:
TOOLS: [0, 3, 5, 7, 9]
LIBRARIES: [1, 2, 4]

If a category has no relevant items, use an empty list, e.g., TOOLS: []

IMPORTANT GUIDELINES:
1. Be generous but not excessive - aim to include all potentially relevant resources
2. ALWAYS prioritize database tools for general queries - include as many database tools as possible
3. Include all literature search tools
4. For libraries, include those that provide functions needed for analysis
5. Don't exclude resources just because they're not explicitly mentioned in the query
6. When in doubt about a tool, include it rather than exclude it
"""

    @staticmethod
    def _pick_by_index(items: list, indices: list) -> list:
        """Return ``items`` at ``indices``, skipping invalid or repeated positions."""
        # dict.fromkeys de-duplicates while keeping the LLM's ordering. The lower
        # bound matters: a negative index would otherwise wrap around and pick
        # an item from the end of the list.
        return [items[i] for i in dict.fromkeys(indices) if 0 <= i < len(items)]

    def _format_resources_for_prompt(self, resources: list) -> str:
        """Format resources as numbered lines for inclusion in the prompt.

        Dicts contribute their 'name' and 'description' entries, strings are
        used verbatim, and any other object contributes its ``name`` and
        ``description`` attributes (falling back to ``str(resource)`` and an
        empty description). Returns "None available" for an empty list.
        """
        formatted = []
        for i, resource in enumerate(resources):
            if isinstance(resource, str):
                formatted.append(f"{i}. {resource}")
                continue

            if isinstance(resource, dict):
                name = resource.get("name", f"Resource {i}")
                description = resource.get("description", "")
            else:
                name = getattr(resource, "name", str(resource))
                description = getattr(resource, "description", "")
            formatted.append(f"{i}. {name}: {description}")

        return "\n".join(formatted) if formatted else "None available"

    def _parse_llm_response(self, response: str) -> dict:
        """Parse the LLM response to extract the selected indices.

        Looks for ``TOOLS: [...]`` and ``LIBRARIES: [...]`` (case-insensitive;
        the bracketed list may span several lines) and returns
        ``{"tools": [...], "libraries": [...]}`` with the integers found.
        Tokens that are not integers are skipped individually, so one bad
        token does not discard the rest of the list. A missing category
        yields an empty list.
        """
        if not isinstance(response, str):
            # Some LLM wrappers return non-string content; search its text form.
            response = str(response)

        selected_indices = {"tools": [], "libraries": []}

        for category, indices in selected_indices.items():
            match = re.search(
                rf"{category.upper()}:\s*\[(.*?)\]",
                response,
                re.IGNORECASE | re.DOTALL,
            )
            if match is None:
                continue

            for token in match.group(1).split(","):
                try:
                    indices.append(int(token.strip()))
                except ValueError:
                    # Empty or non-numeric token: ignore it, keep the others.
                    continue

        return selected_indices