git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
[3.13] gh-120108: Fix deepcopying of AST trees with .parent attributes (GH-120114...
author Jelle Zijlstra <jelle.zijlstra@gmail.com>
Tue, 25 Jun 2024 15:39:29 +0000 (08:39 -0700)
committer GitHub <noreply@github.com>
Tue, 25 Jun 2024 15:39:29 +0000 (15:39 +0000)
(cherry picked from commit 42b2c9d78da7ebd6bd5925a4d4c78aec3c9e78e6)

Lib/test/test_ast.py
Misc/NEWS.d/next/Library/2024-06-05-08-02-46.gh-issue-120108.4U9BL8.rst [new file with mode: 0644]
Parser/asdl_c.py
Python/Python-ast.c

diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py
index 5422c861ffb5c0d1e880f8a3168e1a5281a4d320..93bd5dec6eac7418a6418f2f7085753556dc801e 100644 (file)
@@ -1,5 +1,6 @@
 import ast
 import builtins
+import copy
 import dis
 import enum
 import os
@@ -23,7 +24,7 @@ from test.support import os_helper, script_helper
 from test.support.ast_helper import ASTTestMixin
 
 def to_tuple(t):
-    if t is None or isinstance(t, (str, int, complex)) or t is Ellipsis:
+    if t is None or isinstance(t, (str, int, complex, float, bytes)) or t is Ellipsis:
         return t
     elif isinstance(t, list):
         return [to_tuple(e) for e in t]
@@ -971,15 +972,6 @@ class AST_Tests(unittest.TestCase):
         x = ast.Sub()
         self.assertEqual(x._fields, ())
 
-    def test_pickling(self):
-        import pickle
-
-        for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
-            for ast in (compile(i, "?", "exec", 0x400) for i in exec_tests):
-                with self.subTest(ast=ast, protocol=protocol):
-                    ast2 = pickle.loads(pickle.dumps(ast, protocol))
-                    self.assertEqual(to_tuple(ast2), to_tuple(ast))
-
     def test_invalid_sum(self):
         pos = dict(lineno=2, col_offset=3)
         m = ast.Module([ast.Expr(ast.expr(**pos), **pos)], [])
@@ -1222,6 +1214,80 @@ class AST_Tests(unittest.TestCase):
         for node, attr, source in tests:
             self.assert_none_check(node, attr, source)
 
+
+class CopyTests(unittest.TestCase):
+    """Test copying and pickling AST nodes."""
+
+    def test_pickling(self):
+        import pickle
+
+        for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
+            for code in exec_tests:
+                with self.subTest(code=code, protocol=protocol):
+                    tree = compile(code, "?", "exec", 0x400)
+                    ast2 = pickle.loads(pickle.dumps(tree, protocol))
+                    self.assertEqual(to_tuple(ast2), to_tuple(tree))
+
+    def test_copy_with_parents(self):
+        # gh-120108
+        code = """
+        ('',)
+        while i < n:
+            if ch == '':
+                ch = format[i]
+                if ch == '':
+                    if freplace is None:
+                        '' % getattr(object)
+                elif ch == '':
+                    if zreplace is None:
+                        if hasattr:
+                            offset = object.utcoffset()
+                            if offset is not None:
+                                if offset.days < 0:
+                                    offset = -offset
+                                h = divmod(timedelta(hours=0))
+                                if u:
+                                    zreplace = '' % (sign,)
+                                elif s:
+                                    zreplace = '' % (sign,)
+                                else:
+                                    zreplace = '' % (sign,)
+                elif ch == '':
+                    if Zreplace is None:
+                        Zreplace = ''
+                        if hasattr(object):
+                            s = object.tzname()
+                            if s is not None:
+                                Zreplace = s.replace('')
+                    newformat.append(Zreplace)
+                else:
+                    push('')
+            else:
+                push(ch)
+
+        """
+        tree = ast.parse(textwrap.dedent(code))
+        for node in ast.walk(tree):
+            for child in ast.iter_child_nodes(node):
+                child.parent = node
+        try:
+            with support.infinite_recursion(200):
+                tree2 = copy.deepcopy(tree)
+        finally:
+            # Singletons like ast.Load() are shared; make sure we don't
+            # leave them mutated after this test.
+            for node in ast.walk(tree):
+                if hasattr(node, "parent"):
+                    del node.parent
+
+        for node in ast.walk(tree2):
+            for child in ast.iter_child_nodes(node):
+                if hasattr(child, "parent") and not isinstance(child, (
+                    ast.expr_context, ast.boolop, ast.unaryop, ast.cmpop, ast.operator,
+                )):
+                    self.assertEqual(to_tuple(child.parent), to_tuple(node))
+
+
 class ASTHelpers_Test(unittest.TestCase):
     maxDiff = None
 
diff --git a/Misc/NEWS.d/next/Library/2024-06-05-08-02-46.gh-issue-120108.4U9BL8.rst b/Misc/NEWS.d/next/Library/2024-06-05-08-02-46.gh-issue-120108.4U9BL8.rst
new file mode 100644 (file)
index 0000000..e310695
--- /dev/null
@@ -0,0 +1,2 @@
+Fix calling :func:`copy.deepcopy` on :mod:`ast` trees that have been
+modified to have references to parent nodes. Patch by Jelle Zijlstra.
diff --git a/Parser/asdl_c.py b/Parser/asdl_c.py
index 9961d23629abc580e1f772970e3289614c8ba334..e338656a5b1eb9647380edf9885f83985c6ba958 100755 (executable)
@@ -1064,17 +1064,22 @@ ast_type_reduce(PyObject *self, PyObject *unused)
         return NULL;
     }
 
-    PyObject *dict = NULL, *fields = NULL, *remaining_fields = NULL,
-             *remaining_dict = NULL, *positional_args = NULL;
+    PyObject *dict = NULL, *fields = NULL, *positional_args = NULL;
     if (PyObject_GetOptionalAttr(self, state->__dict__, &dict) < 0) {
         return NULL;
     }
     PyObject *result = NULL;
     if (dict) {
-        // Serialize the fields as positional args if possible, because if we
-        // serialize them as a dict, during unpickling they are set only *after*
-        // the object is constructed, which will now trigger a DeprecationWarning
-        // if the AST type has required fields.
+        // Unpickling (or copying) works as follows:
+        // - Construct the object with only positional arguments
+        // - Set the fields from the dict
+        // We have two constraints:
+        // - We must set all the required fields in the initial constructor call,
+        //   or the unpickling or deepcopying of the object will trigger DeprecationWarnings.
+        // - We must not include child nodes in the positional args, because
+        //   that may trigger runaway recursion during copying (gh-120108).
+        // To satisfy both constraints, we set all the fields to None in the
+        // initial list of positional args, and then set the fields from the dict.
         if (PyObject_GetOptionalAttr((PyObject*)Py_TYPE(self), state->_fields, &fields) < 0) {
             goto cleanup;
         }
@@ -1084,11 +1089,6 @@ ast_type_reduce(PyObject *self, PyObject *unused)
                 Py_DECREF(dict);
                 goto cleanup;
             }
-            remaining_dict = PyDict_Copy(dict);
-            Py_DECREF(dict);
-            if (!remaining_dict) {
-                goto cleanup;
-            }
             positional_args = PyList_New(0);
             if (!positional_args) {
                 goto cleanup;
@@ -1099,7 +1099,7 @@ ast_type_reduce(PyObject *self, PyObject *unused)
                     goto cleanup;
                 }
                 PyObject *value;
-                int rc = PyDict_Pop(remaining_dict, name, &value);
+                int rc = PyDict_GetItemRef(dict, name, &value);
                 Py_DECREF(name);
                 if (rc < 0) {
                     goto cleanup;
@@ -1107,7 +1107,7 @@ ast_type_reduce(PyObject *self, PyObject *unused)
                 if (!value) {
                     break;
                 }
-                rc = PyList_Append(positional_args, value);
+                rc = PyList_Append(positional_args, Py_None);
                 Py_DECREF(value);
                 if (rc < 0) {
                     goto cleanup;
@@ -1117,8 +1117,7 @@ ast_type_reduce(PyObject *self, PyObject *unused)
             if (!args_tuple) {
                 goto cleanup;
             }
-            result = Py_BuildValue("ONO", Py_TYPE(self), args_tuple,
-                                   remaining_dict);
+            result = Py_BuildValue("ONN", Py_TYPE(self), args_tuple, dict);
         }
         else {
             result = Py_BuildValue("O()N", Py_TYPE(self), dict);
@@ -1129,8 +1128,6 @@ ast_type_reduce(PyObject *self, PyObject *unused)
     }
 cleanup:
     Py_XDECREF(fields);
-    Py_XDECREF(remaining_fields);
-    Py_XDECREF(remaining_dict);
     Py_XDECREF(positional_args);
     return result;
 }
diff --git a/Python/Python-ast.c b/Python/Python-ast.c
index 7aa1c5119d8f2837298a0f9ad4288e4e70a0a271..01ffea1869350b523c52ea5464d832898b2975e0 100644 (file)
@@ -5263,17 +5263,22 @@ ast_type_reduce(PyObject *self, PyObject *unused)
         return NULL;
     }
 
-    PyObject *dict = NULL, *fields = NULL, *remaining_fields = NULL,
-             *remaining_dict = NULL, *positional_args = NULL;
+    PyObject *dict = NULL, *fields = NULL, *positional_args = NULL;
     if (PyObject_GetOptionalAttr(self, state->__dict__, &dict) < 0) {
         return NULL;
     }
     PyObject *result = NULL;
     if (dict) {
-        // Serialize the fields as positional args if possible, because if we
-        // serialize them as a dict, during unpickling they are set only *after*
-        // the object is constructed, which will now trigger a DeprecationWarning
-        // if the AST type has required fields.
+        // Unpickling (or copying) works as follows:
+        // - Construct the object with only positional arguments
+        // - Set the fields from the dict
+        // We have two constraints:
+        // - We must set all the required fields in the initial constructor call,
+        //   or the unpickling or deepcopying of the object will trigger DeprecationWarnings.
+        // - We must not include child nodes in the positional args, because
+        //   that may trigger runaway recursion during copying (gh-120108).
+        // To satisfy both constraints, we set all the fields to None in the
+        // initial list of positional args, and then set the fields from the dict.
         if (PyObject_GetOptionalAttr((PyObject*)Py_TYPE(self), state->_fields, &fields) < 0) {
             goto cleanup;
         }
@@ -5283,11 +5288,6 @@ ast_type_reduce(PyObject *self, PyObject *unused)
                 Py_DECREF(dict);
                 goto cleanup;
             }
-            remaining_dict = PyDict_Copy(dict);
-            Py_DECREF(dict);
-            if (!remaining_dict) {
-                goto cleanup;
-            }
             positional_args = PyList_New(0);
             if (!positional_args) {
                 goto cleanup;
@@ -5298,7 +5298,7 @@ ast_type_reduce(PyObject *self, PyObject *unused)
                     goto cleanup;
                 }
                 PyObject *value;
-                int rc = PyDict_Pop(remaining_dict, name, &value);
+                int rc = PyDict_GetItemRef(dict, name, &value);
                 Py_DECREF(name);
                 if (rc < 0) {
                     goto cleanup;
@@ -5306,7 +5306,7 @@ ast_type_reduce(PyObject *self, PyObject *unused)
                 if (!value) {
                     break;
                 }
-                rc = PyList_Append(positional_args, value);
+                rc = PyList_Append(positional_args, Py_None);
                 Py_DECREF(value);
                 if (rc < 0) {
                     goto cleanup;
@@ -5316,8 +5316,7 @@ ast_type_reduce(PyObject *self, PyObject *unused)
             if (!args_tuple) {
                 goto cleanup;
             }
-            result = Py_BuildValue("ONO", Py_TYPE(self), args_tuple,
-                                   remaining_dict);
+            result = Py_BuildValue("ONN", Py_TYPE(self), args_tuple, dict);
         }
         else {
             result = Py_BuildValue("O()N", Py_TYPE(self), dict);
@@ -5328,8 +5327,6 @@ ast_type_reduce(PyObject *self, PyObject *unused)
     }
 cleanup:
     Py_XDECREF(fields);
-    Py_XDECREF(remaining_fields);
-    Py_XDECREF(remaining_dict);
     Py_XDECREF(positional_args);
     return result;
 }