Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bpo-38870: Docstring support for function/class/module nodes #17760

Merged
merged 12 commits into from
Mar 2, 2020
52 changes: 42 additions & 10 deletions Lib/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,20 @@ def set_precedence(self, precedence, *nodes):
for node in nodes:
self._precedences[node] = precedence

def get_raw_docstring(self, node):
"""If a docstring node is found in the body of the *node* parameter, return
said docstring node, None otherwise."""
if not isinstance(
node, (AsyncFunctionDef, FunctionDef, ClassDef, Module)
) or len(node.body) < 1:
return None
node = node.body[0]
if not isinstance(node, Expr):
return None
node = node.value
if isinstance(node, Constant) and isinstance(node.value, str):
return node

def traverse(self, node):
if isinstance(node, list):
for item in node:
Expand All @@ -681,9 +695,15 @@ def visit(self, node):
self.traverse(node)
return "".join(self._source)

def _body_helper(self, node):
isidentical marked this conversation as resolved.
Show resolved Hide resolved
if (docstring := self.get_raw_docstring(node)):
self._write_docstring(docstring)
self.traverse(node.body[1:])
else:
self.traverse(node.body)

def visit_Module(self, node):
for subnode in node.body:
self.traverse(subnode)
self._body_helper(node)

def visit_Expr(self, node):
self.fill()
Expand Down Expand Up @@ -850,15 +870,15 @@ def visit_ClassDef(self, node):
self.traverse(e)

with self.block():
self.traverse(node.body)
self._body_helper(node)

def visit_FunctionDef(self, node):
self.__FunctionDef_helper(node, "def")
self._function_helper(node, "def")

def visit_AsyncFunctionDef(self, node):
self.__FunctionDef_helper(node, "async def")
self._function_helper(node, "async def")

def __FunctionDef_helper(self, node, fill_suffix):
def _function_helper(self, node, fill_suffix):
self.write("\n")
for deco in node.decorator_list:
self.fill("@")
Expand All @@ -871,15 +891,15 @@ def __FunctionDef_helper(self, node, fill_suffix):
self.write(" -> ")
self.traverse(node.returns)
with self.block():
self.traverse(node.body)
self._body_helper(node)

def visit_For(self, node):
self.__For_helper("for ", node)
self._for_helper("for ", node)

def visit_AsyncFor(self, node):
self.__For_helper("async for ", node)
self._for_helper("async for ", node)

def __For_helper(self, fill, node):
def _for_helper(self, fill, node):
self.fill(fill)
self.traverse(node.target)
self.write(" in ")
Expand Down Expand Up @@ -974,6 +994,18 @@ def _fstring_FormattedValue(self, node, write):
def visit_Name(self, node):
self.write(node.id)

def _write_docstring(self, node):
self.fill()
if node.kind == "u":
self.write("u")

value = node.value.replace("\\", "\\\\")
value = value.replace('"""', '""\"')
isidentical marked this conversation as resolved.
Show resolved Hide resolved
if value[-1] == '"':
value = value.replace('"', '\\"', -1)

self.write(f'"""{value}"""')

def _write_constant(self, value):
if isinstance(value, (float, complex)):
# Substitute overflowing decimal literal for AST infinities.
Expand Down
27 changes: 27 additions & 0 deletions Lib/test/test_unparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,16 @@ def test_invalid_set(self):
def test_invalid_yield_from(self):
self.check_invalid(ast.YieldFrom(value=None))

def test_docstrings(self):
docstrings = (
'this ends with double quote"',
'this includes a """triple quote"""'
)
for docstring in docstrings:
# check as Module docstrings for easy testing
self.check_roundtrip(f"'{docstring}'")


class CosmeticTestCase(ASTTestCase):
"""Test if there are cosmetic issues caused by unnecesary additions"""

Expand Down Expand Up @@ -321,6 +331,23 @@ def test_simple_expressions_parens(self):
self.check_src_roundtrip("call((yield x))")
self.check_src_roundtrip("return x + (yield x)")

def test_docstrings(self):
docstrings = (
'"""simple doc string"""',
'''"""A more complex one
with some newlines"""''',
'''"""Foo bar baz

empty newline"""''',
'"""With some \t"""',
'"""Foo "bar" baz """',
)

keywords = ("class", "def", "async def")
isidentical marked this conversation as resolved.
Show resolved Hide resolved

for keyword in keywords:
for docstring in docstrings:
self.check_src_roundtrip(f"{keyword} foo():\n {docstring}")

class DirectoryTestCase(ASTTestCase):
"""Test roundtrip behaviour on all files in Lib and Lib/test."""
Expand Down