# Copyright 2025 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
# Modified from Dream repos: https://github.com/HKUNLP/Dream

"""Post-processing of LLM-generated Python code.

The helpers in this module parse and re-emit code with the standard-library
``ast`` module (no tree-sitter dependency is used here).
"""

import os
import sys
import pathlib

ROOT = os.path.dirname(os.path.abspath(__file__))
# Make the parent and grand-parent directories of this file importable.
sys.path.extend([os.path.dirname(ROOT), os.path.dirname(os.path.dirname(ROOT))])

import ast
import traceback
from typing import Dict, List, Optional, Set, Tuple


def refine_text(text: str) -> str:
    """Normalize whitespace in *text*.

    Tabs are replaced with spaces, ``\\r\\n`` / ``\\r`` line endings become
    ``\\n``, surrounding whitespace is stripped and exactly one trailing
    newline is appended.
    """
    text = text.replace("\t", " ")
    text = text.replace("\r\n", "\n").replace("\r", "\n")
    return text.strip() + "\n"


def syntax_check(code: str, verbose: bool = False) -> bool:
    """Return ``True`` when *code* can be parsed by ``ast.parse``.

    Args:
        code: Python source text to check.
        verbose: When true, print the traceback of the parse failure.

    Returns:
        ``True`` if parsing succeeds, ``False`` otherwise.
    """
    try:
        ast.parse(code)
    except (SyntaxError, ValueError, MemoryError, RecursionError):
        # Besides SyntaxError/MemoryError, ast.parse may raise ValueError
        # (e.g. null bytes on older Python versions) or RecursionError (very
        # deeply nested source). All of them mean "not usable code".
        if verbose:
            traceback.print_exc()
        return False
    return True


def extract_longest_valid_code(text: str) -> str:
    """Return the contiguous run of lines that parses and is the longest.

    Only the first 100 lines of *text* are considered. "Longest" is measured
    in non-blank lines; on ties the earliest candidate found wins. An empty
    string is returned when no window with at least one non-blank line parses.
    """
    lines = text.splitlines()[:100]
    best_count = 0
    best_snippet = ""

    for start in range(len(lines)):
        non_blank = 0
        for end in range(start, len(lines)):
            if lines[end].strip():
                non_blank += 1
            # A window can only replace the current best if it has strictly
            # more non-blank lines, so skip the (expensive) parse otherwise.
            if non_blank <= best_count:
                continue
            snippet = "\n".join(lines[start:end + 1])
            if syntax_check(snippet):
                best_count = non_blank
                best_snippet = snippet

    return best_snippet


def get_deps(nodes: List[Tuple[str, ast.AST]]) -> Dict[str, Set[str]]:
    """Map each definition name to the identifiers used inside its node.

    For every ``(name, node)`` pair the descendants of *node* are scanned;
    the ``id`` of each ``ast.Name`` and the ``attr`` of each ``ast.Attribute``
    are collected.

    Args:
        nodes: Pairs of definition name and the AST node defining it.

    Returns:
        A dict from definition name to the set of collected identifiers.
    """
    name2deps: Dict[str, Set[str]] = {}
    for name, node in nodes:
        deps: Set[str] = set()
        stack = [node]
        while stack:
            current = stack.pop()
            for child in ast.iter_child_nodes(current):
                if isinstance(child, ast.Name):
                    deps.add(child.id)
                    continue
                if isinstance(child, ast.Attribute):
                    deps.add(child.attr)
                # Keep descending, including through attribute accesses, so
                # that the object left of the dot (``Helper`` in
                # ``Helper().run()``) is recorded as a dependency too.
                stack.append(child)
        name2deps[name] = deps
    return name2deps


def get_function_dependency(entrypoint: str, call_graph: Dict[str, Set[str]]) -> Set[str]:
    """Return every name reachable from *entrypoint* in *call_graph*.

    Args:
        entrypoint: Name to start the traversal from.
        call_graph: Mapping from a name to the names it depends on.

    Returns:
        The set of reachable names, always including *entrypoint* itself.
    """
    visited: Set[str] = set()
    pending = [entrypoint]
    while pending:
        # Traversal order does not affect the resulting set, so pop from the
        # end (O(1)) instead of the front of the list.
        current = pending.pop()
        if current in visited:
            continue
        visited.add(current)
        pending.extend(call_graph.get(current, set()) - visited)
    return visited


def get_definition_name(node: ast.AST) -> Optional[str]:
    """Return the name defined by *node*, or ``None``.

    Functions and classes yield their own name. An assignment yields the name
    of its first target when that target is a plain ``ast.Name``.
    """
    if isinstance(node, (ast.FunctionDef, ast.ClassDef)):
        return node.name
    if isinstance(node, ast.Assign):
        targets = node.targets
        if targets and isinstance(targets[0], ast.Name):
            return targets[0].id
    return None


def has_return_statement(node: ast.AST) -> bool:
    """Return ``True`` if any ``ast.Return`` occurs anywhere under *node*."""
    return any(isinstance(n, ast.Return) for n in ast.walk(node))


def sanitize(text: str, entrypoint: Optional[str] = None) -> str:
    """Extract the useful top-level code from raw model output.

    The text is normalized, its longest parsable run of lines is extracted,
    and the following top-level statements are kept: imports, classes,
    functions that contain a ``return`` statement, and assignments whose
    first target is a plain name. A later definition with the same name
    replaces an earlier one.

    Args:
        text: Raw text that may contain Python code.
        entrypoint: Optional name; when given, only definitions reachable
            from it (by identifier use) are emitted. Imports are always kept.

    Returns:
        The kept statements rendered with ``ast.unparse`` and joined by
        newlines, imports first, then definitions in first-seen order.
    """
    text = refine_text(text)
    code = extract_longest_valid_code(text)
    tree = ast.parse(code)

    definitions: Dict[str, Tuple[str, ast.AST]] = {}
    imports: List[ast.AST] = []

    for node in tree.body:
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            imports.append(node)
        elif isinstance(node, ast.ClassDef):
            definitions[node.name] = ('class', node)
        elif isinstance(node, ast.FunctionDef):
            # Functions without any return statement are not kept.
            if has_return_statement(node):
                definitions[node.name] = ('function', node)
        elif isinstance(node, ast.Assign):
            name = get_definition_name(node)
            if name:
                definitions[name] = ('variable', node)

    reachable: Optional[Set[str]] = None
    if entrypoint:
        name2deps = get_deps([(name, node) for name, (_, node) in definitions.items()])
        reachable = get_function_dependency(entrypoint, name2deps)

    sanitized_output = [ast.unparse(node) for node in imports]
    for name, (_, node) in definitions.items():
        if reachable is None or name in reachable:
            sanitized_output.append(ast.unparse(node))

    return "\n".join(sanitized_output)