###########################################################################################
# Basic ToolRetriever benchmarking for measuring retrieval rate for a certain custom tool #
# Author: Ryan Ding #
###########################################################################################
import random

from nltk.corpus import wordnet
from langchain_ollama import ChatOllama

from histopath.model.retriever import ToolRetriever
from histopath.tool.tool_registry import ToolRegistry
from histopath.utils import read_module2api

LLM = ChatOllama(model='gpt-oss:120b', temperature=0.7)
PROMPT_v1 = 'Caption the whole slide image into patches in directory ./test/directory/'
PROMPT_v2 = 'Caption the whole slide images already segmented into patches in directory ./test/directory'
RUNS = 100
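
# NOTE: synonym_replace below relies on the NLTK WordNet corpus; if it has not
# been downloaded yet (assumption: it is not bundled with the app), fetch it
# once with:
#   import nltk; nltk.download('wordnet')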
def synonym_replace(text, p_replace=0.2, protected_words=None):
    """Prompt perturbation via replacement of words with their synonyms.

    Parameters
    ----------
    text: str
        prompt to perturb
    p_replace: float
        probability of replacing any given word (default: 0.2)
    protected_words: set
        words protected from perturbation (default: None)

    Returns
    -------
    str
        perturbed prompt
    """
    words = text.split()
    new_words = []
    for w in words:
        # Leave protected tokens (e.g. paths or tool arguments) untouched.
        if protected_words and w in protected_words:
            new_words.append(w)
            continue
        if random.random() < p_replace:
            # Use the first (most common) WordNet synset; underscores in
            # lemma names mark multi-word synonyms, so restore the spaces.
            syns = wordnet.synsets(w)
            if syns:
                lemma_names = syns[0].lemma_names()
                if lemma_names:
                    w = random.choice(lemma_names).replace('_', ' ')
        new_words.append(w)
    return ' '.join(new_words)
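
# Example (illustrative): protecting the directory token keeps the target path
# in the prompts intact while the surrounding words are paraphrased:
#   synonym_replace(PROMPT_v1, p_replace=0.3, protected_words={'./test/directory/'})
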
def add_typo(text, p_typo=0.02):
    """Prompt perturbation via injection of character-level typos.

    Parameters
    ----------
    text: str
        prompt to perturb
    p_typo: float
        probability of introducing a typo at any given character (default: 0.02)

    Returns
    -------
    str
        perturbed prompt
    """
    new_text = list(text)
    for i in range(len(new_text)):
        if random.random() < p_typo:
            # Overwrite the character (spaces included) with a random
            # lowercase letter.
            new_text[i] = random.choice('abcdefghijklmnopqrstuvwxyz')
    return ''.join(new_text)
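
# Convenience wrapper (assumption: not in the original file) that composes the
# two perturbations above so a single call yields a noisy prompt variant.
def perturb_prompt(text, p_replace=0.2, p_typo=0.02, protected_words=None):
    """Apply synonym replacement, then character-level typos."""
    return add_typo(synonym_replace(text, p_replace, protected_words), p_typo)
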
class ToolBenchmark:
    """Benchmark how often the retriever selects the expected tools for a prompt."""

    def __init__(self, llm, prompts, runs, targets):
        self.llm = llm
        self.targets = targets
        self.prompts = prompts
        self.runs = runs
        self.module2api = read_module2api()
        self.registry = ToolRegistry(self.module2api)
        self.retriever = ToolRetriever()
        self.all_tools = self.registry.tools
        self.resources = {"tools": self.all_tools}
    def retrieve_tools(self, prompt):
        """Run prompt-based retrieval and return the selected tools as a set."""
        selected_resources = self.retriever.prompt_based_retrieval(
            query=prompt, resources=self.resources, llm=self.llm
        )
        return set(selected_resources["tools"])
    def evaluate(self):
        """Run retrieval `self.runs` times per prompt and return per-target hit rates."""
        hits = dict()
        for prompt in self.prompts:
            for _ in range(self.runs):
                tools = self.retrieve_tools(prompt)
                for target in self.targets:
                    # Count how many runs retrieved the expected tool.
                    if target in tools:
                        hits[target] = hits.get(target, 0) + 1
        total = self.runs * len(self.prompts)
        return {target: hits.get(target, 0) / total for target in self.targets}
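
# Sketch (assumption: an extension, not part of the original benchmark) that
# perturbs the prompt on every run, so the hit rate reflects robustness to
# noisy phrasing rather than just sampling variance in the LLM.
def evaluate_perturbed(benchmark, prompt, targets, runs=RUNS):
    hits = {target: 0 for target in targets}
    for _ in range(runs):
        noisy = perturb_prompt(prompt)
        tools = benchmark.retrieve_tools(noisy)
        for target in targets:
            if target in tools:
                hits[target] += 1
    return {target: count / runs for target, count in hits.items()}
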
def main():
    # Placeholder target name; swap in the registry names the prompts should hit.
    benchmark = ToolBenchmark(LLM, [PROMPT_v1, PROMPT_v2], RUNS, ['caption_patches'])
    print(benchmark.evaluate())

if __name__ == '__main__':
    main()