]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
bpo-43417: Better buffer handling for ast.unparse (GH-24772)
authorBatuhan Taskaya <isidentical@gmail.com>
Sat, 8 May 2021 23:32:04 +0000 (02:32 +0300)
committerGitHub <noreply@github.com>
Sat, 8 May 2021 23:32:04 +0000 (02:32 +0300)
Lib/ast.py
Lib/test/test_unparse.py

index 66bcee8a252a0457b5844dda5a0c03e2712ac417..18163d6b7bd1634840d98588afd95db064e0d9eb 100644 (file)
@@ -678,7 +678,6 @@ class _Unparser(NodeVisitor):
 
     def __init__(self, *, _avoid_backslashes=False):
         self._source = []
-        self._buffer = []
         self._precedences = {}
         self._type_ignores = {}
         self._indent = 0
@@ -721,14 +720,15 @@ class _Unparser(NodeVisitor):
         """Append a piece of text"""
         self._source.append(text)
 
-    def buffer_writer(self, text):
-        self._buffer.append(text)
+    @contextmanager
+    def buffered(self, buffer = None):
+        if buffer is None:
+            buffer = []
 
-    @property
-    def buffer(self):
-        value = "".join(self._buffer)
-        self._buffer.clear()
-        return value
+        original_source = self._source
+        self._source = buffer
+        yield buffer
+        self._source = original_source
 
     @contextmanager
     def block(self, *, extra = None):
@@ -1127,9 +1127,9 @@ class _Unparser(NodeVisitor):
     def visit_JoinedStr(self, node):
         self.write("f")
         if self._avoid_backslashes:
-            self._fstring_JoinedStr(node, self.buffer_writer)
-            self._write_str_avoiding_backslashes(self.buffer)
-            return
+            with self.buffered() as buffer:
+                self._write_fstring_inner(node)
+            return self._write_str_avoiding_backslashes("".join(buffer))
 
         # If we don't need to avoid backslashes globally (i.e., we only need
         # to avoid them inside FormattedValues), it's cosmetically preferred
@@ -1137,60 +1137,62 @@ class _Unparser(NodeVisitor):
         # for cases like: f"{x}\n". To accomplish this, we keep track of what
         # in our buffer corresponds to FormattedValues and what corresponds to
         # Constant parts of the f-string, and allow escapes accordingly.
-        buffer = []
+        fstring_parts = []
         for value in node.values:
-            meth = getattr(self, "_fstring_" + type(value).__name__)
-            meth(value, self.buffer_writer)
-            buffer.append((self.buffer, isinstance(value, Constant)))
-        new_buffer = []
-        quote_types = _ALL_QUOTES
-        for value, is_constant in buffer:
-            # Repeatedly narrow down the list of possible quote_types
+            with self.buffered() as buffer:
+                self._write_fstring_inner(value)
+            fstring_parts.append(
+                ("".join(buffer), isinstance(value, Constant))
+            )
+
+        new_fstring_parts = []
+        quote_types = list(_ALL_QUOTES)
+        for value, is_constant in fstring_parts:
             value, quote_types = self._str_literal_helper(
-                value, quote_types=quote_types,
-                escape_special_whitespace=is_constant
+                value,
+                quote_types=quote_types,
+                escape_special_whitespace=is_constant,
             )
-            new_buffer.append(value)
-        value = "".join(new_buffer)
+            new_fstring_parts.append(value)
+
+        value = "".join(new_fstring_parts)
         quote_type = quote_types[0]
         self.write(f"{quote_type}{value}{quote_type}")
 
+    def _write_fstring_inner(self, node):
+        if isinstance(node, JoinedStr):
+            # for both the f-string itself, and format_spec
+            for value in node.values:
+                self._write_fstring_inner(value)
+        elif isinstance(node, Constant) and isinstance(node.value, str):
+            value = node.value.replace("{", "{{").replace("}", "}}")
+            self.write(value)
+        elif isinstance(node, FormattedValue):
+            self.visit_FormattedValue(node)
+        else:
+            raise ValueError(f"Unexpected node inside JoinedStr, {node!r}")
+
     def visit_FormattedValue(self, node):
-        self.write("f")
-        self._fstring_FormattedValue(node, self.buffer_writer)
-        self._write_str_avoiding_backslashes(self.buffer)
+        def unparse_inner(inner):
+            unparser = type(self)(_avoid_backslashes=True)
+            unparser.set_precedence(_Precedence.TEST.next(), inner)
+            return unparser.visit(inner)
 
-    def _fstring_JoinedStr(self, node, write):
-        for value in node.values:
-            meth = getattr(self, "_fstring_" + type(value).__name__)
-            meth(value, write)
-
-    def _fstring_Constant(self, node, write):
-        if not isinstance(node.value, str):
-            raise ValueError("Constants inside JoinedStr should be a string.")
-        value = node.value.replace("{", "{{").replace("}", "}}")
-        write(value)
-
-    def _fstring_FormattedValue(self, node, write):
-        write("{")
-        unparser = type(self)(_avoid_backslashes=True)
-        unparser.set_precedence(_Precedence.TEST.next(), node.value)
-        expr = unparser.visit(node.value)
-        if expr.startswith("{"):
-            write(" ")  # Separate pair of opening brackets as "{ {"
-        if "\\" in expr:
-            raise ValueError("Unable to avoid backslash in f-string expression part")
-        write(expr)
-        if node.conversion != -1:
-            conversion = chr(node.conversion)
-            if conversion not in "sra":
-                raise ValueError("Unknown f-string conversion.")
-            write(f"!{conversion}")
-        if node.format_spec:
-            write(":")
-            meth = getattr(self, "_fstring_" + type(node.format_spec).__name__)
-            meth(node.format_spec, write)
-        write("}")
+        with self.delimit("{", "}"):
+            expr = unparse_inner(node.value)
+            if "\\" in expr:
+                raise ValueError(
+                    "Unable to avoid backslash in f-string expression part"
+                )
+            if expr.startswith("{"):
+                # Separate pair of opening brackets as "{ {"
+                self.write(" ")
+            self.write(expr)
+            if node.conversion != -1:
+                self.write(f"!{chr(node.conversion)}")
+            if node.format_spec:
+                self.write(":")
+                self._write_fstring_inner(node.format_spec)
 
     def visit_Name(self, node):
         self.write(node.id)
index 9f67b49f3a6b2b35cec0c9335593a72671b9c3ad..534431bc9698357ea47cfdf8321cc0e3a7f10208 100644 (file)
@@ -149,6 +149,27 @@ class UnparseTestCase(ASTTestCase):
     # Tests for specific bugs found in earlier versions of unparse
 
     def test_fstrings(self):
+        self.check_ast_roundtrip("f'a'")
+        self.check_ast_roundtrip("f'{{}}'")
+        self.check_ast_roundtrip("f'{{5}}'")
+        self.check_ast_roundtrip("f'{{5}}5'")
+        self.check_ast_roundtrip("f'X{{}}X'")
+        self.check_ast_roundtrip("f'{a}'")
+        self.check_ast_roundtrip("f'{ {1:2}}'")
+        self.check_ast_roundtrip("f'a{a}a'")
+        self.check_ast_roundtrip("f'a{a}{a}a'")
+        self.check_ast_roundtrip("f'a{a}a{a}a'")
+        self.check_ast_roundtrip("f'{a!r}x{a!s}12{{}}{a!a}'")
+        self.check_ast_roundtrip("f'{a:10}'")
+        self.check_ast_roundtrip("f'{a:100_000{10}}'")
+        self.check_ast_roundtrip("f'{a!r:10}'")
+        self.check_ast_roundtrip("f'{a:a{b}10}'")
+        self.check_ast_roundtrip(
+                "f'a{b}{c!s}{d!r}{e!a}{f:a}{g:a{b}}{h!s:a}"
+                "{j!s:{a}b}{k!s:a{b}c}{l!a:{b}c{d}}{x+y=}'"
+        )
+
+    def test_fstrings_special_chars(self):
         # See issue 25180
         self.check_ast_roundtrip(r"""f'{f"{0}"*3}'""")
         self.check_ast_roundtrip(r"""f'{f"{y}"*3}'""")
@@ -323,15 +344,13 @@ class UnparseTestCase(ASTTestCase):
     def test_invalid_raise(self):
         self.check_invalid(ast.Raise(exc=None, cause=ast.Name(id="X")))
 
-    def test_invalid_fstring_constant(self):
-        self.check_invalid(ast.JoinedStr(values=[ast.Constant(value=100)]))
-
-    def test_invalid_fstring_conversion(self):
+    def test_invalid_fstring_value(self):
         self.check_invalid(
-            ast.FormattedValue(
-                value=ast.Constant(value="a", kind=None),
-                conversion=ord("Y"),  # random character
-                format_spec=None,
+            ast.JoinedStr(
+                values=[
+                    ast.Name(id="test"),
+                    ast.Constant(value="test")
+                ]
             )
         )