###########################################################################################
# Basic ToolRetriever benchmarking for measuring retrieval rate for a certain custom tool #
#                                                                                         #
# Author: Ryan Ding                                                                       #
###########################################################################################
import random

from langchain_ollama import ChatOllama
from nltk.corpus import wordnet

from histopath.model.retriever import ToolRetriever
from histopath.tool.tool_registry import ToolRegistry
from histopath.utils import read_module2api

# Shared LLM used by the retriever for prompt-based tool selection.
LLM = ChatOllama(model='gpt-oss:120b', temperature=0.7)

PROMPT_v1 = 'Caption the whole slide into patches into directory ./test/directory/'
# NOTE(review): "pathces" looks like a typo for "patches" — left untouched because the
# misspelled prompt may be a deliberate perturbation for this benchmark; confirm intent.
PROMPT_v2 = 'Caption the whole slide images already segmented into pathces in directory ./test/directory'

# Default number of retrieval runs per benchmark evaluation.
RUNS = 100


def synonym_replace(text, p_replace=0.2, protected_words=None):
    """Prompt perturbation via replacement of words with their synonyms.

    Each whitespace-delimited word is independently replaced, with probability
    ``p_replace``, by a lemma drawn from its first WordNet synset (underscores
    in multi-word lemmas become spaces). Words in ``protected_words`` are
    never touched.

    Parameters
    ----------
    text : str
        Prompt to perturb.
    p_replace : float
        Probability of replacing any given word (default: 0.2).
    protected_words : set, optional
        Words protected from perturbation (default: None).

    Returns
    -------
    str
        Perturbed prompt.
    """
    words = text.split()
    new_words = []
    for w in words:
        # Protected words pass through verbatim.
        if protected_words and w in protected_words:
            new_words.append(w)
            continue
        if random.random() < p_replace:
            syns = wordnet.synsets(w)
            if syns:
                # Only the first synset is consulted; its lemma list may be empty.
                lemma_names = syns[0].lemma_names()
                if lemma_names:
                    w = random.choice(lemma_names).replace('_', ' ')
        new_words.append(w)
    return ' '.join(new_words)


def add_typo(text, p_typo=0.02):
    """Prompt perturbation via integration of character-level typos.

    Each character is independently replaced, with probability ``p_typo``, by
    a random lowercase ASCII letter.

    Parameters
    ----------
    text : str
        Prompt to perturb.
    p_typo : float
        Probability of introducing a typo at any given character (default: 0.02).

    Returns
    -------
    str
        Perturbed prompt.
    """
    new_text = list(text)
    for i in range(len(new_text)):
        if random.random() < p_typo:
            new_text[i] = random.choice('abcdefghijklmnopqrstuvwxyz')
    return ''.join(new_text)


class ToolBenchmark:
    """Benchmark harness measuring how often target tools are retrieved.

    Runs the prompt-based tool retriever repeatedly and counts, per target
    tool, how many runs include that tool in the retrieved set.
    """

    def __init__(self, llm, prompts, runs, targets):
        # llm: chat model handed to the retriever for each retrieval call.
        self.llm = llm
        # targets: iterable of tool identifiers whose retrieval rate we measure.
        self.targets = targets
        # prompts: candidate prompt(s) for retrieval; see retrieve_tools().
        self.prompts = prompts
        # runs: number of retrieval repetitions per evaluate() call.
        self.runs = runs
        self.module2api = read_module2api()
        self.registry = ToolRegistry(self.module2api)
        self.retriever = ToolRetriever()
        self.all_tools = self.registry.tools
        # Resource pool the retriever selects from on every call.
        self.resources = {"tools": self.all_tools}

    def retrieve_tools(self, prompt=None):
        """Run one prompt-based retrieval and return the retrieved tool set.

        Parameters
        ----------
        prompt : str, optional
            Query to retrieve against. Defaults to the module-level PROMPT_v2,
            preserving the original hard-coded behavior. (NOTE(review): the
            original ignored self.prompts entirely; this parameter lets callers
            supply one explicitly.)

        Returns
        -------
        set
            Tools selected by the retriever for the query.
        """
        query = PROMPT_v2 if prompt is None else prompt
        selected_resources = self.retriever.prompt_based_retrieval(
            query=query, resources=self.resources, llm=self.llm
        )
        # set(...) directly — no need for an intermediate list comprehension.
        return set(selected_resources["tools"])

    def evaluate(self):
        """Run the retriever ``self.runs`` times and count per-target hits.

        Returns
        -------
        dict
            Mapping of target tool -> number of runs in which it was retrieved.
            (BUG FIX: the original computed this dict but never returned it,
            discarding the entire benchmark result.)
        """
        hits = {}
        for _ in range(self.runs):
            tools = self.retrieve_tools()
            for target in self.targets:
                # Count the runs in which the proper tool was retrieved.
                if target in tools:
                    hits[target] = hits.get(target, 0) + 1
        return hits


def main():
    # TODO: wire up a concrete benchmark run, e.g.:
    #   bench = ToolBenchmark(LLM, [PROMPT_v1, PROMPT_v2], RUNS, targets={...})
    #   print(bench.evaluate())
    pass


if __name__ == '__main__':
    main()