from utils import enumerate_resume, make_printv, write_jsonl, resume_success_count
from executors import executor_factory
from generators import generator_factory, model_factory
from typing import List, Dict, Any, Tuple
import math
import sys
import random

sys.set_int_max_str_digits(100000)  # Raise the int-to-str conversion limit to 100,000 digits

react_prompt_header = "Here are some previous solutions and the corresponding test results.\n"
react_prompt_starter = "\n\nYour solution:\n"
extra_header = "\n\nName the function answer()"

class Node:
    def __init__(self, solution: str, parent=None, context="", depth=0):
        self.solution = solution
        self.parent = parent
        self.children = []
        self.value = 0
        self.visits = 0
        self.context = context
        self.depth = depth
        self.reflection = ""
        self.test_feedback = ""

    def uct(self, exploration_weight=1.0):
        if self.visits == 0:
            # Unvisited nodes fall back to their raw value rather than float('inf').
            return self.value
        return (self.value / self.visits) + exploration_weight * math.sqrt(math.log(self.parent.visits) / self.visits)

    def best_child(self):
        if not self.children:  # Check if children list is empty
            return None
        return max(self.children, key=lambda child: child.uct())

    def best_child_value(self):
        if not self.children:  # Check if children list is empty
            return None
        return max(self.children, key=lambda child: child.value)

    def update(self, reward: float):
        self.visits += 1
        self.value += reward
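
# UCT intuition (illustrative numbers, not from the original file): with
# exploration_weight=1.0, a child with value 1.5 over 3 visits whose parent has
# 10 visits scores 1.5 / 3 + sqrt(ln(10) / 3) ≈ 0.5 + 0.88 ≈ 1.38, so lightly
# visited children keep a chance of being selected over better-scoring but
# heavily visited siblings.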

def prune_context_blocks(context: str, max_length: int) -> str:
    """Prune the context to fit within max_length by dropping whole blocks,
    using 'Previous Trial' as the block delimiter."""
    if len(context) <= max_length:
        return context
    # Split by the block delimiter "Previous Trial".
    blocks = context.split('Previous Trial')
    # Drop the earliest blocks until the context fits within max_length.
    while len('Previous Trial'.join(blocks)) > max_length and blocks:
        blocks.pop(0)
    return 'Previous Trial'.join(blocks)

def gather_context_from_tree(node: Node) -> Tuple[List[str], List[str]]:
    """
    Walk up the tree from the given node and gather the test feedback and
    reflections of its ancestors, capped at the two most recent nodes.

    Args:
        node (Node): The node to start gathering context from.

    Returns:
        Tuple[List[str], List[str]]: The accumulated feedback and reflections.
    """
    accumulated_feedback = []
    accumulated_reflection = []
    num_nodes = 0
    while node and num_nodes < 2:
        num_nodes += 1
        if node.test_feedback:
            accumulated_feedback.append(node.test_feedback)
        if node.reflection:
            accumulated_reflection.append(node.reflection)
        node = node.parent
    # Reverse the lists so that context from the earliest nodes comes first.
    return accumulated_feedback[::-1], accumulated_reflection[::-1]

def sample_n_random(items: List[str], n: int) -> List[str]:
    """Sample min(n, len(items)) random items from a list."""
    assert n >= 0
    if n >= len(items):
        return items
    return random.sample(items, n)
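
# Illustration (hypothetical values, not from the original file):
# sample_n_random(["t1", "t2", "t3"], 2) returns two of the three tests at
# random, while prune_context_blocks("A Previous Trial B Previous Trial C", 10)
# drops the earliest 'Previous Trial' blocks until the result fits in 10 chars.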

def run_lats(
    model_name: str,
    language: str,
    max_iters: int,
    verbose: bool,
    instruction: str = "Write some code to print Hello World in Python",
    n_samples: int = 3,
    depth: int = 5,
) -> str:
    exe = executor_factory(language)
    gen = generator_factory(language)
    model = model_factory(model_name)

    num_success = 0  # Counter for successful solutions
    cur_func_impl = None
    item = {}

    # Generate internal tests for the instruction and sample one for this attempt.
    tests = gen.internal_tests(instruction + extra_header, model, 1)
    tests_i = sample_n_random(tests, 1)

    # Retry until the generator returns a first implementation.
    while cur_func_impl is None:
        cur_func_impl = gen.func_impl(instruction + extra_header, model, "simple")
    root = Node(cur_func_impl)  # initial solution (for pass@1 metric)

    # Lists for logging
    reflections = []
    implementations = []
    test_feedback = []
    is_solved = False

    # First attempt
    implementations.append(cur_func_impl)
    assert isinstance(cur_func_impl, str)
    is_passing, feedback, _ = exe.execute(cur_func_impl, tests_i)
    test_feedback.append(feedback)

    # If the first attempt already passes, return it and skip the search.
    if is_passing:
        num_success += 1
        return cur_func_impl

    reflection = gen.self_reflection(cur_func_impl, feedback, model)
    reflections += [reflection]
    root.test_feedback = feedback
    root.reflection = reflection
    max_iters = int(max_iters)

    for cur_iter in range(max_iters):
        # Selection: walk down the tree following the best UCT child.
        tests_i = sample_n_random(tests, 1)
        node = root
        trajectory = {
            'solutions': [],
            'feedbacks': []
        }
        while node.children:
            node = node.best_child()
            trajectory['solutions'].append(node.solution)

        # Expansion: sample n_samples new solutions from the selected node.
        for _ in range(n_samples):
            new_solution = None
            strategy = "mcts"
            prev_func_impl = node.solution
            feedback = node.test_feedback
            reflection = node.reflection
            acc_feedback, acc_reflection = gather_context_from_tree(node)

            while new_solution is None:
                new_solution = gen.func_impl(
                    func_sig=instruction + extra_header,
                    model=model,
                    strategy=strategy,
                    prev_func_impl=prev_func_impl,
                    feedback=feedback,
                    self_reflection=reflection,
                    acc_feedback=acc_feedback,
                    acc_reflection=acc_reflection,
                )

            combined_context = "\nPrevious Trial\n\n" + new_solution
            child = Node(new_solution, parent=node, context=combined_context, depth=node.depth + 1)
            node.children.append(child)

        # Simulation: run the sampled tests against each new child and score it.
        reward_real = 0
        for child in node.children:
            is_passing_internal, feedback_internal, _ = exe.execute(child.solution, tests_i)
            if not is_passing_internal:
                reflection = gen.self_reflection(child.solution, feedback_internal, model)
                reflections.append(reflection)
                child.reflection = reflection
                child.test_feedback = feedback_internal
                child.context += "\n\nPrevious Trial\n\n" + child.solution + "\n\nTest results: \n" + feedback_internal + "\n\nSelf-reflection: " + reflection
            else:
                child.context += "\n\nPrevious Trial\n\n" + child.solution + "\n\nTest results: \n" + feedback_internal
                child.reflection = ""
                child.test_feedback = feedback_internal

            if "Tested passed:" in feedback_internal:
                # Take the part before "Tests failed:" (it contains the passed tests).
                passed_section = feedback_internal.split("Tests failed:")[0]
                # Count the non-empty lines after "Tested passed:" to get the number of passing tests.
                reward_internal = len([line for line in passed_section.split("Tested passed:")[1].splitlines() if line.strip() != ''])
                reward_internal = reward_internal / len(tests_i)
            else:
                reward_internal = 0

            if is_passing_internal or cur_iter == max_iters - 1:
                item["solution"] = child.solution
                if is_passing_internal:
                    is_solved = True  # mark as solved so the search stops and this solution is returned
                break

        if is_solved:
            break

        # Backpropagation: propagate the last child's reward up to the root.
        reward = reward_internal + reward_real
        child.update(reward)
        temp = child
        while temp.parent:
            temp = temp.parent
            temp.update(reward)

    # Choose the best solution after all iterations.
    if is_solved:
        best_solution = item["solution"]
    else:
        best_solution = root.best_child_value().solution
    item["solution"] = best_solution

    return best_solution
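

# Minimal usage sketch (not part of the original file). The model name,
# language key, and iteration counts below are placeholders; substitute
# whatever this repo's model_factory and executor_factory actually accept.
if __name__ == "__main__":
    solution = run_lats(
        model_name="gpt-3.5-turbo",  # hypothetical model name
        language="python",           # hypothetical language key
        max_iters=2,
        verbose=True,
        instruction="Write a function answer() that returns the sum of 2 and 3.",
        n_samples=3,
    )
    print(solution)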