########################################################################################### 
# Basic ToolRetriever benchmarking for measuring retrieval rate for a certain custom tool #
# Author: Ryan Ding                                                                       #
###########################################################################################
import random
from nltk.corpus import wordnet
from histopath.model.retriever import ToolRetriever
from histopath.tool.tool_registry import ToolRegistry
from histopath.utils import read_module2api
from langchain_ollama import ChatOllama
        
# Chat model used by the retriever for prompt-based tool selection.
LLM = ChatOllama(model='gpt-oss:120b', temperature=0.7)
# Two phrasings of the same captioning request, used as benchmark queries.
PROMPT_v1 = 'Caption the whole slide into patches into directory ./test/directory/'
# NOTE(review): 'pathces' below looks like a typo of 'patches' -- confirm
# whether it is an intentional perturbation or should be corrected.
PROMPT_v2 = 'Caption the whole slide images already segmented into pathces in directory ./test/directory'
# Number of retrieval repetitions per benchmark evaluation.
RUNS = 100

def synonym_replace(text, p_replace=0.2, protected_words=None):
    """Prompt perturbation via replacement of words with their synonyms.

    Each whitespace-delimited word is independently replaced, with
    probability ``p_replace``, by a lemma from the word's first WordNet
    synset. Multi-word lemmas have underscores converted to spaces.
    Words without any WordNet synsets are left unchanged.

    Parameters
    ----------
    text: str
        prompt to perturb
    p_replace: float
        probability of replacing any given word (default: 0.2)
    protected_words: set
        words protected from perturbation (default: None)
        NOTE(review): membership is case-sensitive and attached
        punctuation is not stripped -- confirm intended usage.

    Returns
    -------
    str
        perturbed prompt
    """
    words = text.split()
    new_words = []
    for w in words:
        if protected_words and w in protected_words:
            new_words.append(w)
            continue
        if random.random() < p_replace:
            syns = wordnet.synsets(w)
            if syns:
                # Only consider lemmas that actually differ from the original
                # word; the first synset's lemma list usually contains the word
                # itself, which would make the "perturbation" a silent no-op.
                lemma_names = [
                    name for name in syns[0].lemma_names()
                    if name.replace('_', ' ').lower() != w.lower()
                ]
                if lemma_names:
                    w = random.choice(lemma_names).replace('_', ' ')
        new_words.append(w)
    return ' '.join(new_words)


def add_typo(text, p_typo=0.02):
    """Prompt perturbation via injection of character-level typos.

    Every character position (including spaces and punctuation) is
    independently overwritten, with probability ``p_typo``, by a random
    lowercase ASCII letter.

    Parameters
    ----------
    text: str
        prompt to perturb
    p_typo: float
        probability of introducing a typo at any given character (default: 0.02)

    Returns
    -------
    str
        perturbed prompt
    """
    alphabet = 'abcdefghijklmnopqrstuvwxyz'
    return ''.join(
        random.choice(alphabet) if random.random() < p_typo else ch
        for ch in text
    )

class ToolBenchmark:
    """Benchmark harness measuring how often a ToolRetriever surfaces target tools.

    Repeats prompt-based retrieval and counts, per target tool, the number
    of runs in which it appeared in the retrieved set.
    """

    def __init__(self, llm, prompts, runs, targets):
        # llm: chat model forwarded to the retriever on each query
        # prompts: candidate prompts for retrieval (see retrieve_tools)
        # runs: number of retrieval repetitions per evaluation
        # targets: tool identifiers whose retrieval rate is measured
        self.llm = llm
        self.targets = targets
        self.prompts = prompts
        self.runs = runs
        self.module2api = read_module2api()
        self.registry = ToolRegistry(self.module2api)
        self.retriever = ToolRetriever()
        self.all_tools = self.registry.tools
        self.resources = { "tools": self.all_tools }

    def retrieve_tools(self, prompt=None):
        """Run one prompt-based retrieval and return the retrieved tools as a set.

        Parameters
        ----------
        prompt: str, optional
            query to retrieve with; defaults to the module-level PROMPT_v2
            (the original hard-coded behavior) when omitted. Previously the
            ``prompts`` constructor argument was stored but never usable here.
        """
        query = PROMPT_v2 if prompt is None else prompt
        selected_resources = self.retriever.prompt_based_retrieval(
            query=query, resources=self.resources, llm=self.llm
        )
        return set(selected_resources["tools"])

    def evaluate(self):
        """Repeat retrieval ``self.runs`` times and tally hits per target.

        Returns
        -------
        dict
            mapping of target tool -> number of runs in which it was
            retrieved. (The original computed this but never returned it,
            so results were silently discarded.)
        """
        hits = dict()
        for _ in range(self.runs):
            tools = self.retrieve_tools()
            for target in self.targets:
                # amount of times proper tool retrieved
                if target in tools:
                    hits[target] = hits.get(target, 0) + 1
        return hits

            
def main():
    """Script entry point.

    Currently a stub; TODO: construct a ToolBenchmark (e.g. with LLM,
    the module prompts, RUNS, and target tool names) and report rates.
    """
    pass

# Run the benchmark only when executed as a script, not on import.
if __name__ == '__main__':
    main()