Spaces:
Running
Running
| from __future__ import annotations | |
| from typing import Union | |
| import libcst as cst | |
| from libcst._nodes.module import Module | |
# CST node types that can carry a docstring: a module, a class, or a function.
DocstringNode = Union[cst.Module, cst.ClassDef, cst.FunctionDef]
def get_docstring_statement(body: DocstringNode) -> cst.SimpleStatementLine | None:
    """Return the statement line that holds the docstring of a module, class or function.

    A docstring is the first statement of the node's body when that statement is a
    bare ``str`` literal expression (a plain string or a concatenation of plain
    strings). Byte strings and concatenations involving f-strings are not docstrings.

    Args:
        body: The ``Module``, ``ClassDef`` or ``FunctionDef`` node to inspect. (The
            parameter name is kept for backward compatibility; it is the node itself,
            not its body.)

    Returns:
        The ``SimpleStatementLine`` containing the docstring, or ``None`` if there is
        none. One-line bodies such as ``def f(): "doc"`` contain no statement line
        and therefore also yield ``None``.
    """
    if isinstance(body, cst.Module):
        statements = body.body
    else:
        # ClassDef/FunctionDef keep their statements inside an IndentedBlock or a
        # SimpleStatementSuite; both expose them through ``.body``.
        statements = body.body.body
    if not statements:
        return None
    statement = statements[0]
    # A SimpleStatementSuite holds small statements (e.g. ``Expr``) directly, so
    # anything that is not a full statement line cannot be returned here.
    if not isinstance(statement, cst.SimpleStatementLine):
        return None
    if not statement.body:
        return None
    expr = statement.body[0]
    if not isinstance(expr, cst.Expr):
        return None
    value = expr.value
    if not isinstance(value, (cst.SimpleString, cst.ConcatenatedString)):
        return None
    # evaluated_value is bytes for b"..." literals and None for concatenations that
    # contain an f-string; neither is a docstring.
    if not isinstance(value.evaluated_value, str):
        return None
    return statement
def has_decorator(node: DocstringNode, name: str) -> bool:
    """Return whether ``node`` is decorated with a decorator called ``name``.

    Matches the bare form (``@overload``), the attribute form
    (``@typing.overload``), and either of those when called
    (``@overload()`` / ``@typing.overload()``).

    Args:
        node: The node to inspect. Nodes without a ``decorators`` attribute
            (e.g. ``Module``) never match.
        name: The decorator's final name component, e.g. ``"overload"``.

    Returns:
        True if any decorator on ``node`` matches ``name``.
    """
    for decorator in getattr(node, "decorators", ()):
        target = decorator.decorator
        if isinstance(target, cst.Call):
            # ``@name(...)``: the name lives on the callee, not on the call.
            target = target.func
        if isinstance(target, cst.Name) and target.value == name:
            return True
        # ``@module.name``: the decorator's own name is the ``attr`` component;
        # ``value`` is the module/object expression in front of the dot.
        if isinstance(target, cst.Attribute) and target.attr.value == name:
            return True
    return False
class DocstringCollector(cst.CSTVisitor):
    """Visitor that records the docstring statement of each module, class and function.

    Every definition is addressed by the tuple of enclosing names, with ``""`` for
    the module itself, e.g. ``("", "Outer", "method")``. Definitions decorated with
    ``overload`` are not recorded.

    Attributes:
        stack: Names of the definitions currently being visited, outermost first.
        docstrings: Maps each definition path to its docstring statement line.
    """

    def __init__(self):
        self.stack: list[str] = []
        self.docstrings: dict[tuple[str, ...], cst.SimpleStatementLine] = {}

    def _push(self, label: str) -> None:
        """Enter a scope named ``label``."""
        self.stack.append(label)

    def visit_Module(self, node: cst.Module) -> bool | None:
        self._push("")

    def leave_Module(self, node: cst.Module) -> None:
        self._leave(node)

    def visit_ClassDef(self, node: cst.ClassDef) -> bool | None:
        self._push(node.name.value)

    def leave_ClassDef(self, node: cst.ClassDef) -> None:
        self._leave(node)

    def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None:
        self._push(node.name.value)

    def leave_FunctionDef(self, node: cst.FunctionDef) -> None:
        self._leave(node)

    def _leave(self, node: DocstringNode) -> None:
        """Exit the current scope and remember ``node``'s docstring, if it has one."""
        path = tuple(self.stack)
        del self.stack[-1]
        found = None if has_decorator(node, "overload") else get_docstring_statement(node)
        if found:
            self.docstrings[path] = found
class DocstringTransformer(cst.CSTTransformer):
    """Transformer that writes collected docstrings into a CST.

    Definitions are addressed by the tuple of enclosing names, with ``""`` for the
    module itself, e.g. ``("", "Outer", "method")``. When a path has an entry in
    ``docstrings``, that statement replaces the node's existing docstring, or is
    inserted as the first statement if the node has none. Nodes decorated with
    ``overload`` are left unchanged.

    Attributes:
        stack: Names of the definitions currently being visited, outermost first.
        docstrings: Maps definition paths to the docstring statement to install.
    """

    def __init__(
        self,
        docstrings: dict[tuple[str, ...], cst.SimpleStatementLine],
    ):
        self.stack: list[str] = []
        self.docstrings = docstrings

    def visit_Module(self, node: cst.Module) -> bool | None:
        self.stack.append("")

    def leave_Module(self, original_node: Module, updated_node: Module) -> Module:
        return self._leave(original_node, updated_node)

    def visit_ClassDef(self, node: cst.ClassDef) -> bool | None:
        self.stack.append(node.name.value)

    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.CSTNode:
        return self._leave(original_node, updated_node)

    def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None:
        self.stack.append(node.name.value)

    def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.CSTNode:
        return self._leave(original_node, updated_node)

    def _leave(self, original_node: DocstringNode, updated_node: DocstringNode) -> DocstringNode:
        """Exit the current scope and install its collected docstring, if any."""
        key = tuple(self.stack)
        self.stack.pop()
        if has_decorator(updated_node, "overload"):
            return updated_node
        statement = self.docstrings.get(key)
        if not statement:
            return updated_node
        # True when the node already starts with a docstring that must be replaced.
        replace = get_docstring_statement(original_node) is not None
        if isinstance(updated_node, cst.Module):
            return self._install_in_module(updated_node, statement, replace)
        return self._install_in_definition(updated_node, statement, replace)

    def _install_in_module(
        self, module: cst.Module, statement: cst.SimpleStatementLine, replace: bool
    ) -> cst.Module:
        """Make ``statement`` the first statement of ``module``."""
        body = module.body
        if replace:
            return module.with_changes(body=(statement, *body[1:]))
        if not body:
            return module.with_changes(body=(statement,))
        # Keep a blank line between the new docstring and the following code. It is
        # attached to the next statement's leading_lines because EmptyLine is not a
        # statement and does not belong in Module.body.
        first = body[0].with_changes(leading_lines=(cst.EmptyLine(), *body[0].leading_lines))
        return module.with_changes(body=(statement, first, *body[1:]))

    def _install_in_definition(
        self,
        node: cst.ClassDef | cst.FunctionDef,
        statement: cst.SimpleStatementLine,
        replace: bool,
    ) -> cst.ClassDef | cst.FunctionDef:
        """Make ``statement`` the first statement of a class or function body."""
        suite = node.body
        if isinstance(suite, cst.SimpleStatementSuite):
            # One-line body such as ``def f(): return x``. A SimpleStatementSuite only
            # holds small statements, so a statement line cannot be inserted into it;
            # expand it into an indented block instead. A leading string expression in
            # the suite is the old inline docstring and is dropped.
            small = list(suite.body)
            if small and self._is_docstring_expr(small[0]):
                small = small[1:]
            lines: tuple[cst.BaseStatement, ...] = (statement,)
            if small:
                lines += (
                    cst.SimpleStatementLine(
                        body=small, trailing_whitespace=suite.trailing_whitespace
                    ),
                )
            return node.with_changes(body=cst.IndentedBlock(body=lines))
        rest = suite.body[1:] if replace else suite.body
        return node.with_changes(body=suite.with_changes(body=(statement, *rest)))

    @staticmethod
    def _is_docstring_expr(small: cst.BaseSmallStatement) -> bool:
        """Return whether ``small`` is a bare ``str`` literal expression."""
        if not isinstance(small, cst.Expr):
            return False
        value = small.value
        return isinstance(value, (cst.SimpleString, cst.ConcatenatedString)) and isinstance(
            value.evaluated_value, str
        )
def merge_docstring(code: str, documented_code: str) -> str:
    """Copy the docstrings of ``documented_code`` into ``code``.

    Both sources are parsed with libcst. Docstrings are collected from
    ``documented_code`` keyed by the path of enclosing module/class/function names,
    then installed into the definitions of ``code`` that have the same path.
    Documented definitions with no counterpart in ``code`` are ignored.

    Args:
        code: The source that receives the docstrings.
        documented_code: The source whose docstrings are copied.

    Returns:
        The source of ``code`` regenerated from its transformed CST.
    """
    # ``code`` is parsed first, so its syntax errors are reported before those of
    # ``documented_code``.
    target_tree, source_tree = cst.parse_module(code), cst.parse_module(documented_code)
    collector = DocstringCollector()
    source_tree.visit(collector)
    return target_tree.visit(DocstringTransformer(collector.docstrings)).code