|
|
|
|
|
|
|
|
|
|
|
|
|
|
import random |
|
|
from nltk.corpus import wordnet |
|
|
from histopath.model.retriever import ToolRetriever |
|
|
from histopath.tool.tool_registry import ToolRegistry |
|
|
from histopath.utils import read_module2api |
|
|
from langchain_ollama import ChatOllama |
|
|
|
|
|
# Shared chat model for retrieval runs.  temperature=0.7 keeps retrieval
# stochastic, so repeated runs can legitimately differ (that variance is what
# the benchmark measures).  NOTE(review): assumes the 'gpt-oss:120b' model is
# available in the local Ollama instance — confirm before running.
LLM = ChatOllama(model='gpt-oss:120b', temperature=0.7)

# Candidate benchmark prompts (two phrasings of the same captioning task).
PROMPT_v1 = 'Caption the whole slide into patches into directory ./test/directory/'

# NOTE(review): 'pathces' looks like a typo for 'patches' — confirm whether it
# is intentional, since this module benchmarks robustness to perturbed prompts.
PROMPT_v2 = 'Caption the whole slide images already segmented into pathces in directory ./test/directory'

# Number of retrieval runs per evaluation.
RUNS = 100
|
|
|
|
|
def synonym_replace(text, p_replace=0.2, protected_words=None):
    """Prompt perturbation via replacement of words with WordNet synonyms.

    Each whitespace-delimited word is independently replaced, with
    probability ``p_replace``, by a lemma drawn from its first WordNet
    synset.  Words with no synset (or an empty lemma list) are kept as-is.

    Parameters
    ----------
    text: str
        prompt to perturb
    p_replace: float
        probability of replacing any given word (default: 0.2)
    protected_words: set
        words exempt from perturbation (default: None)

    Returns
    -------
    str
        perturbed prompt
    """
    perturbed = []
    for token in text.split():
        # Protected words bypass perturbation entirely (no random draw).
        if protected_words and token in protected_words:
            perturbed.append(token)
            continue
        if random.random() < p_replace:
            token = _wordnet_substitute(token)
        perturbed.append(token)
    return ' '.join(perturbed)


def _wordnet_substitute(word):
    """Return a random synonym of *word* from its first WordNet synset.

    Falls back to *word* itself when no synset or lemma is available.
    Multi-word lemmas use spaces instead of WordNet's underscores.
    """
    synsets = wordnet.synsets(word)
    if not synsets:
        return word
    lemmas = synsets[0].lemma_names()
    if not lemmas:
        return word
    return random.choice(lemmas).replace('_', ' ')
|
|
|
|
|
|
|
|
def add_typo(text, p_typo=0.02): |
|
|
"""Prompt perturbation via integration of character-level typos. |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
text: str |
|
|
prompt to perturb |
|
|
p_typo: float |
|
|
probability of introducing a typo at any given character (default: 0.02) |
|
|
|
|
|
Returns |
|
|
------- |
|
|
str |
|
|
perturbed prompt |
|
|
|
|
|
""" |
|
|
new_text = list(text) |
|
|
for i in range(len(new_text)): |
|
|
if random.random() < p_typo: |
|
|
new_text[i] = random.choice('abcdefghijklmnopqrstuvwxyz') |
|
|
return ''.join(new_text) |
|
|
|
|
|
class ToolBenchmark:
    """Benchmark the stability of LLM-driven tool retrieval.

    Runs prompt-based tool retrieval repeatedly and counts, per target
    tool, how often it appears in the retrieved set.
    """

    def __init__(self, llm, prompts, runs, targets):
        """Set up the registry/retriever and record benchmark parameters.

        Parameters
        ----------
        llm:
            chat model passed through to the retriever
        prompts:
            candidate prompts for the benchmark
        runs: int
            number of retrieval runs per evaluation
        targets:
            tool names whose retrieval frequency is measured
        """
        self.llm = llm
        self.targets = targets
        self.prompts = prompts
        self.runs = runs
        self.module2api = read_module2api()
        self.registry = ToolRegistry(self.module2api)
        self.retriever = ToolRetriever()
        self.all_tools = self.registry.tools
        self.resources = {"tools": self.all_tools}

    def retrieve_tools(self, prompt=None):
        """Retrieve the tool set selected for *prompt*.

        Parameters
        ----------
        prompt: str, optional
            query used for retrieval; defaults to the module-level
            PROMPT_v2 (the original hard-coded behavior).

        Returns
        -------
        set
            names of the tools the retriever selected
        """
        # NOTE(review): the original always queried PROMPT_v2 and ignored
        # self.prompts; the optional parameter keeps that default while
        # letting callers supply a prompt explicitly.
        if prompt is None:
            prompt = PROMPT_v2
        selected_resources = self.retriever.prompt_based_retrieval(
            query=prompt, resources=self.resources, llm=self.llm
        )
        return set(selected_resources["tools"])

    def evaluate(self):
        """Run retrieval ``self.runs`` times and tally target hits.

        Returns
        -------
        dict
            mapping of target tool name -> number of runs in which it was
            retrieved (targets never retrieved are absent).
        """
        hits = {}
        for _ in range(self.runs):
            tools = self.retrieve_tools()
            for target in self.targets:
                if target in tools:
                    hits[target] = hits.get(target, 0) + 1
        # Bug fix: the original computed `hits` but never returned it.
        return hits
|
|
|
|
|
|
|
|
def main():
    """Command-line entry point; benchmark wiring is not implemented yet."""


if __name__ == '__main__':
    main()