]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
bpo-38870: Simplify sequence interleaves in ast.unparse (GH-17892)
authorBatuhan Taşkaya <47358913+isidentical@users.noreply.github.com>
Mon, 9 Mar 2020 20:27:03 +0000 (23:27 +0300)
committerGitHub <noreply@github.com>
Mon, 9 Mar 2020 20:27:03 +0000 (20:27 +0000)
Lib/ast.py
Lib/test/test_unparse.py

index 2719f6ff7ac5938d029e89364eb1c97465c232c3..9a3d3806eb8ca7519a625d974e234f52790a14f9 100644 (file)
@@ -613,6 +613,16 @@ class _Unparser(NodeVisitor):
                 inter()
                 f(x)
 
+    def items_view(self, traverser, items):
+        """Traverse and separate the given *items* with a comma and append it to
+        the buffer. If *items* is a single item sequence, a trailing comma
+        will be added."""
+        if len(items) == 1:
+            traverser(items[0])
+            self.write(",")
+        else:
+            self.interleave(lambda: self.write(", "), traverser, items)
+
     def fill(self, text=""):
         """Indent a piece of text and append it, according to the current
         indentation level"""
@@ -1020,11 +1030,7 @@ class _Unparser(NodeVisitor):
         value = node.value
         if isinstance(value, tuple):
             with self.delimit("(", ")"):
-                if len(value) == 1:
-                    self._write_constant(value[0])
-                    self.write(",")
-                else:
-                    self.interleave(lambda: self.write(", "), self._write_constant, value)
+                self.items_view(self._write_constant, value)
         elif value is ...:
             self.write("...")
         else:
@@ -1116,12 +1122,7 @@ class _Unparser(NodeVisitor):
 
     def visit_Tuple(self, node):
         with self.delimit("(", ")"):
-            if len(node.elts) == 1:
-                elt = node.elts[0]
-                self.traverse(elt)
-                self.write(",")
-            else:
-                self.interleave(lambda: self.write(", "), self.traverse, node.elts)
+            self.items_view(self.traverse, node.elts)
 
     unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"}
     unop_precedence = {
@@ -1264,12 +1265,7 @@ class _Unparser(NodeVisitor):
             if (isinstance(node.slice, Index)
                     and isinstance(node.slice.value, Tuple)
                     and node.slice.value.elts):
-                if len(node.slice.value.elts) == 1:
-                    elt = node.slice.value.elts[0]
-                    self.traverse(elt)
-                    self.write(",")
-                else:
-                    self.interleave(lambda: self.write(", "), self.traverse, node.slice.value.elts)
+                self.items_view(self.traverse, node.slice.value.elts)
             else:
                 self.traverse(node.slice)
 
@@ -1296,12 +1292,7 @@ class _Unparser(NodeVisitor):
             self.traverse(node.step)
 
     def visit_ExtSlice(self, node):
-        if len(node.dims) == 1:
-            elt = node.dims[0]
-            self.traverse(elt)
-            self.write(",")
-        else:
-            self.interleave(lambda: self.write(", "), self.traverse, node.dims)
+        self.items_view(self.traverse, node.dims)
 
     def visit_arg(self, node):
         self.write(node.arg)
index d33f32e2a7fe936fdf8c7cb104177f3731844e01..3d87cfb6daeef888368729e887d06ba169ed8f3e 100644 (file)
@@ -280,6 +280,20 @@ class UnparseTestCase(ASTTestCase):
         self.check_ast_roundtrip(r"""{**{'y': 2}, 'x': 1}""")
         self.check_ast_roundtrip(r"""{**{'y': 2}, **{'x': 1}}""")
 
+    def test_ext_slices(self):
+        self.check_ast_roundtrip("a[i]")
+        self.check_ast_roundtrip("a[i,]")
+        self.check_ast_roundtrip("a[i, j]")
+        self.check_ast_roundtrip("a[()]")
+        self.check_ast_roundtrip("a[i:j]")
+        self.check_ast_roundtrip("a[:j]")
+        self.check_ast_roundtrip("a[i:]")
+        self.check_ast_roundtrip("a[i:j:k]")
+        self.check_ast_roundtrip("a[:j:k]")
+        self.check_ast_roundtrip("a[i::k]")
+        self.check_ast_roundtrip("a[i:j,]")
+        self.check_ast_roundtrip("a[i:j, k]")
+
     def test_invalid_raise(self):
         self.check_invalid(ast.Raise(exc=None, cause=ast.Name(id="X")))
 
@@ -310,6 +324,12 @@ class UnparseTestCase(ASTTestCase):
             # check as Module docstrings for easy testing
             self.check_ast_roundtrip(f"'{docstring}'")
 
+    def test_constant_tuples(self):
+        self.check_src_roundtrip(ast.Constant(value=(1,), kind=None), "(1,)")
+        self.check_src_roundtrip(
+            ast.Constant(value=(1, 2, 3), kind=None), "(1, 2, 3)"
+        )
+
 
 class CosmeticTestCase(ASTTestCase):
     """Test if there are cosmetic issues caused by unnecesary additions"""
@@ -344,20 +364,6 @@ class CosmeticTestCase(ASTTestCase):
         self.check_src_roundtrip("call((yield x))")
         self.check_src_roundtrip("return x + (yield x)")
 
-    def test_subscript(self):
-        self.check_src_roundtrip("a[i]")
-        self.check_src_roundtrip("a[i,]")
-        self.check_src_roundtrip("a[i, j]")
-        self.check_src_roundtrip("a[()]")
-        self.check_src_roundtrip("a[i:j]")
-        self.check_src_roundtrip("a[:j]")
-        self.check_src_roundtrip("a[i:]")
-        self.check_src_roundtrip("a[i:j:k]")
-        self.check_src_roundtrip("a[:j:k]")
-        self.check_src_roundtrip("a[i::k]")
-        self.check_src_roundtrip("a[i:j,]")
-        self.check_src_roundtrip("a[i:j, k]")
-
     def test_docstrings(self):
         docstrings = (
             '"""simple doc string"""',