Spaces:
Sleeping
Sleeping
| # main class chaining Planner, Worker and Solver. | |
| import re | |
| import time | |
| from nodes.Planner import Planner | |
| from nodes.Solver import Solver | |
| from nodes.Worker import * | |
| from utils.util import * | |
class PWS:
    """Pipeline that chains a Planner, tool Workers and a Solver.

    ``run`` asks the planner for a plan, executes every planned tool call
    through ``WORKER_REGISTRY``, hands the collected evidence to the solver
    and returns the answer together with timing, token and cost bookkeeping.
    """

    # Placeholder stored whenever a planned step yields no usable evidence.
    _NO_EVIDENCE = "No evidence found"
    # Evidence variables look like "#E1", "#E2", ..., "#E12".
    _EVIDENCE_VAR = re.compile(r"#E\d+")

    def __init__(self, available_tools=None, fewshot="\n", planner_model="text-davinci-003",
                 solver_model="text-davinci-003"):
        """Build the planner and solver and set up per-run bookkeeping.

        Args:
            available_tools: Names of the tools the planner may use. Defaults
                to ``["Google", "LLM"]``. The sequence is copied so instances
                never share (or mutate) the caller's list.
            fewshot: Few-shot text forwarded to the planner.
            planner_model: Model name used by the planner (also used to look
                up its token unit price).
            solver_model: Model name used by the solver (also used to look up
                its token unit price).
        """
        if available_tools is None:
            available_tools = ["Google", "LLM"]
        self.workers = list(available_tools)
        self.planner = Planner(workers=self.workers,
                               model_name=planner_model,
                               fewshot=fewshot)
        self.solver = Solver(model_name=solver_model)
        self._reinitialize()
        self.planner_token_unit_price = get_token_unit_price(planner_model)
        self.solver_token_unit_price = get_token_unit_price(solver_model)
        self.tool_token_unit_price = get_token_unit_price("text-davinci-003")
        self.google_unit_price = 0.01

    def run(self, input):
        """Answer one question with a plan -> work -> solve pass.

        Args:
            input: The question line, e.g.
                "Question: What is the capital of France?".

        Returns:
            dict with the keys ``wall_time``, ``input``, ``output``,
            ``planner_log``, ``worker_log``, ``solver_log``, ``tool_usage``,
            ``steps``, ``total_tokens``, ``token_cost``, ``tool_cost`` and
            ``total_cost``.
        """
        # Each call starts from a clean slate so results never leak between runs.
        self._reinitialize()
        result = {}
        st = time.time()

        # Plan
        planner_response = self.planner.run(input, log=True)
        plan = planner_response["output"]
        planner_log = planner_response["input"] + planner_response["output"]
        self.plans = self._parse_plans(plan)
        self.planner_evidences = self._parse_planner_evidences(plan)

        # Work
        self._get_worker_evidences()
        # Plan number i (1-based) is paired with evidence "#E<i>". The planner
        # may emit a plan without a matching evidence line, so fall back to
        # the placeholder instead of raising KeyError.
        worker_log = "".join(
            f"{step}\nEvidence:\n{self.worker_evidences.get(f'#E{idx}', self._NO_EVIDENCE)}\n"
            for idx, step in enumerate(self.plans, start=1)
        )

        # Solve
        solver_response = self.solver.run(input, worker_log, log=True)
        output = solver_response["output"]
        solver_log = solver_response["input"] + solver_response["output"]

        planner_tokens = planner_response["prompt_tokens"] + planner_response["completion_tokens"]
        solver_tokens = solver_response["prompt_tokens"] + solver_response["completion_tokens"]
        tool_tokens = self.tool_counter.get("LLM_token", 0) + self.tool_counter.get("Calculator_token", 0)

        result["wall_time"] = time.time() - st
        result["input"] = input
        result["output"] = output
        result["planner_log"] = planner_log
        result["worker_log"] = worker_log
        result["solver_log"] = solver_log
        result["tool_usage"] = self.tool_counter
        # One step per plan plus the final solve step.
        result["steps"] = len(self.plans) + 1
        result["total_tokens"] = planner_tokens + solver_tokens + tool_tokens
        result["token_cost"] = self.planner_token_unit_price * planner_tokens \
                               + self.solver_token_unit_price * solver_tokens \
                               + self.tool_token_unit_price * tool_tokens
        result["tool_cost"] = self.tool_counter.get("Google", 0) * self.google_unit_price
        result["total_cost"] = result["token_cost"] + result["tool_cost"]
        return result

    def _parse_plans(self, response):
        """Return every line of ``response`` that starts with "Plan:"."""
        return [line for line in response.splitlines() if line.startswith("Plan:")]

    def _parse_planner_evidences(self, response):
        """Extract ``#E<n> = Tool[input]`` assignments from the planner output.

        Returns:
            dict mapping each evidence label to its tool call string. A line
            whose label is not exactly ``#E<digits>`` is recorded with the
            "No evidence found" placeholder; a line without "=" is ignored.
        """
        evidences = {}
        for line in response.splitlines():
            # Regex prefix check is safe on short lines such as "#" or "#E".
            if not self._EVIDENCE_VAR.match(line):
                continue
            label, sep, tool_call = line.partition("=")
            if not sep:
                # Looks like an evidence line but assigns nothing.
                continue
            label, tool_call = label.strip(), tool_call.strip()
            # fullmatch accepts multi-digit labels (#E10, #E11, ...) as well.
            if self._EVIDENCE_VAR.fullmatch(label):
                evidences[label] = tool_call
            else:
                evidences[label] = self._NO_EVIDENCE
        return evidences

    def _substitute_evidences(self, tool_input):
        """Replace known ``#E<n>`` variables in ``tool_input`` with "[evidence]".

        A single regex pass is used so that "#E1" never clobbers the prefix
        of "#E10" and already-substituted evidence text is not re-scanned.
        Variables without collected evidence are left untouched.
        """
        def _expand(match):
            var = match.group(0)
            if var in self.worker_evidences:
                return "[" + self.worker_evidences[var] + "]"
            return var

        return self._EVIDENCE_VAR.sub(_expand, tool_input)

    def _get_worker_evidences(self):
        """Run each planned tool call and store its result in ``worker_evidences``.

        Tool calls have the form ``Tool[input]``. Entries without "[" are
        copied through unchanged; tools that are not in ``self.workers`` get
        the "No evidence found" placeholder. ``tool_counter`` is updated with
        the number of Google queries and a token estimate for LLM and
        Calculator calls.
        """
        for e, tool_call in self.planner_evidences.items():
            if "[" not in tool_call:
                self.worker_evidences[e] = tool_call
                continue
            tool, tool_input = tool_call.split("[", 1)
            # Drop the closing bracket only when it is actually there.
            if tool_input.endswith("]"):
                tool_input = tool_input[:-1]
            tool_input = self._substitute_evidences(tool_input)

            if tool not in self.workers:
                self.worker_evidences[e] = self._NO_EVIDENCE
                continue

            evidence = WORKER_REGISTRY[tool].run(tool_input)
            self.worker_evidences[e] = evidence
            if tool == "Google":
                # Counts the number of queries.
                self.tool_counter["Google"] = self.tool_counter.get("Google", 0) + 1
            elif tool == "LLM":
                # Token usage is estimated as character count // 4.
                self.tool_counter["LLM_token"] = self.tool_counter.get("LLM_token", 0) \
                                                 + len(tool_input + evidence) // 4
            elif tool == "Calculator":
                # The calculator prompt template is included in the estimate.
                template = LLMMathChain(llm=OpenAI(), verbose=False).prompt.template
                self.tool_counter["Calculator_token"] = self.tool_counter.get("Calculator_token", 0) \
                                                        + len(template + tool_input + evidence) // 4

    def _reinitialize(self):
        """Reset all per-run state (plans, evidences and tool counters)."""
        self.plans = []
        self.planner_evidences = {}
        self.worker_evidences = {}
        self.tool_counter = {}
class PWS_Base(PWS):
    """PWS preset using the HOTPOTQA_PWS_BASE few-shot prompt with Wikipedia and LLM tools."""

    def __init__(self, fewshot=fewshots.HOTPOTQA_PWS_BASE, planner_model="text-davinci-003",
                 solver_model="text-davinci-003", available_tools=None):
        """Forward the preset defaults to ``PWS.__init__``.

        Args:
            fewshot: Few-shot text for the planner.
            planner_model: Model name used by the planner.
            solver_model: Model name used by the solver.
            available_tools: Tool names; defaults to ``["Wikipedia", "LLM"]``.
                A fresh list is built per instance so the default is never
                shared between instances.
        """
        if available_tools is None:
            available_tools = ["Wikipedia", "LLM"]
        super().__init__(available_tools=available_tools,
                         fewshot=fewshot,
                         planner_model=planner_model,
                         solver_model=solver_model)
class PWS_Extra(PWS):
    """PWS preset using the HOTPOTQA_PWS_EXTRA few-shot prompt with Google, Calculator and LLM tools."""

    def __init__(self, fewshot=fewshots.HOTPOTQA_PWS_EXTRA, planner_model="text-davinci-003",
                 solver_model="text-davinci-003", available_tools=None):
        """Forward the preset defaults to ``PWS.__init__``.

        Args:
            fewshot: Few-shot text for the planner.
            planner_model: Model name used by the planner.
            solver_model: Model name used by the solver.
            available_tools: Tool names; defaults to
                ``["Google", "Calculator", "LLM"]``. A fresh list is built
                per instance so the default is never shared between instances.
        """
        if available_tools is None:
            available_tools = ["Google", "Calculator", "LLM"]
        super().__init__(available_tools=available_tools,
                         fewshot=fewshot,
                         planner_model=planner_model,
                         solver_model=solver_model)