]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-121332: Make AST node constructor check _attributes instead of hardcoding attribut...
authorJelle Zijlstra <jelle.zijlstra@gmail.com>
Thu, 11 Jul 2024 14:34:53 +0000 (07:34 -0700)
committerGitHub <noreply@github.com>
Thu, 11 Jul 2024 14:34:53 +0000 (14:34 +0000)
Lib/test/test_ast.py
Misc/NEWS.d/next/Library/2024-07-03-07-25-21.gh-issue-121332.Iz6FEq.rst [new file with mode: 0644]
Parser/asdl_c.py
Python/Python-ast.c

index eb3aefd5c262f6a1085641f57323b12dfbd20459..497c3f261a1fcaecc84065b637d066b16941cb42 100644 (file)
@@ -1386,15 +1386,7 @@ class CopyTests(unittest.TestCase):
         self.assertEqual(node.y, 1)
 
         y = object()
-        # custom attributes are currently not supported and raise a warning
-        # because the allowed attributes are hard-coded !
-        msg = (
-            "MyNode.__init__ got an unexpected keyword argument 'y'. "
-            "Support for arbitrary keyword arguments is deprecated and "
-            "will be removed in Python 3.15"
-        )
-        with self.assertWarnsRegex(DeprecationWarning, re.escape(msg)):
-            repl = copy.replace(node, y=y)
+        repl = copy.replace(node, y=y)
         # assert that there is no side-effect
         self.assertEqual(node.x, 0)
         self.assertEqual(node.y, 1)
@@ -3250,6 +3242,18 @@ class ASTConstructorTests(unittest.TestCase):
         obj = FieldsAndTypes(a=1)
         self.assertEqual(obj.a, 1)
 
+    def test_custom_attributes(self):
+        class MyAttrs(ast.AST):
+            _attributes = ("a", "b")
+
+        obj = MyAttrs(a=1, b=2)
+        self.assertEqual(obj.a, 1)
+        self.assertEqual(obj.b, 2)
+
+        with self.assertWarnsRegex(DeprecationWarning,
+                                   r"MyAttrs.__init__ got an unexpected keyword argument 'c'."):
+            obj = MyAttrs(c=3)
+
     def test_fields_and_types_no_default(self):
         class FieldsAndTypesNoDefault(ast.AST):
             _fields = ('a',)
diff --git a/Misc/NEWS.d/next/Library/2024-07-03-07-25-21.gh-issue-121332.Iz6FEq.rst b/Misc/NEWS.d/next/Library/2024-07-03-07-25-21.gh-issue-121332.Iz6FEq.rst
new file mode 100644 (file)
index 0000000..480f27e
--- /dev/null
@@ -0,0 +1,4 @@
+Fix constructor of :mod:`ast` nodes with custom ``_attributes``. Previously,
+passing custom attributes would raise a :py:exc:`DeprecationWarning`. Passing
+arguments to the constructor that are not in ``_fields`` or ``_attributes``
+remains deprecated. Patch by Jelle Zijlstra.
index f3667801782f2b7cea8a95e1412eac261b005d6d..e6867f138a5ccb8cf7b0dd9edc1a4820c9ac1bbb 100755 (executable)
@@ -880,7 +880,7 @@ ast_type_init(PyObject *self, PyObject *args, PyObject *kw)
 
     Py_ssize_t i, numfields = 0;
     int res = -1;
-    PyObject *key, *value, *fields, *remaining_fields = NULL;
+    PyObject *key, *value, *fields, *attributes = NULL, *remaining_fields = NULL;
     if (PyObject_GetOptionalAttr((PyObject*)Py_TYPE(self), state->_fields, &fields) < 0) {
         goto cleanup;
     }
@@ -947,22 +947,32 @@ ast_type_init(PyObject *self, PyObject *args, PyObject *kw)
                     goto cleanup;
                 }
             }
-            else if (
-                PyUnicode_CompareWithASCIIString(key, "lineno") != 0 &&
-                PyUnicode_CompareWithASCIIString(key, "col_offset") != 0 &&
-                PyUnicode_CompareWithASCIIString(key, "end_lineno") != 0 &&
-                PyUnicode_CompareWithASCIIString(key, "end_col_offset") != 0
-            ) {
-                if (PyErr_WarnFormat(
-                    PyExc_DeprecationWarning, 1,
-                    "%.400s.__init__ got an unexpected keyword argument '%U'. "
-                    "Support for arbitrary keyword arguments is deprecated "
-                    "and will be removed in Python 3.15.",
-                    Py_TYPE(self)->tp_name, key
-                ) < 0) {
+            else {
+                // Lazily initialize "attributes"
+                if (attributes == NULL) {
+                    attributes = PyObject_GetAttr((PyObject*)Py_TYPE(self), state->_attributes);
+                    if (attributes == NULL) {
+                        res = -1;
+                        goto cleanup;
+                    }
+                }
+                int contains = PySequence_Contains(attributes, key);
+                if (contains == -1) {
                     res = -1;
                     goto cleanup;
                 }
+                else if (contains == 0) {
+                    if (PyErr_WarnFormat(
+                        PyExc_DeprecationWarning, 1,
+                        "%.400s.__init__ got an unexpected keyword argument '%U'. "
+                        "Support for arbitrary keyword arguments is deprecated "
+                        "and will be removed in Python 3.15.",
+                        Py_TYPE(self)->tp_name, key
+                    ) < 0) {
+                        res = -1;
+                        goto cleanup;
+                    }
+                }
             }
             res = PyObject_SetAttr(self, key, value);
             if (res < 0) {
@@ -1045,6 +1055,7 @@ ast_type_init(PyObject *self, PyObject *args, PyObject *kw)
         Py_DECREF(field_types);
     }
   cleanup:
+    Py_XDECREF(attributes);
     Py_XDECREF(fields);
     Py_XDECREF(remaining_fields);
     return res;
index cca2ee409e797871920315496d9361966a8dd11b..4d0db457a8b17281aae2ec4a75d9887bcf8563a6 100644 (file)
@@ -5081,7 +5081,7 @@ ast_type_init(PyObject *self, PyObject *args, PyObject *kw)
 
     Py_ssize_t i, numfields = 0;
     int res = -1;
-    PyObject *key, *value, *fields, *remaining_fields = NULL;
+    PyObject *key, *value, *fields, *attributes = NULL, *remaining_fields = NULL;
     if (PyObject_GetOptionalAttr((PyObject*)Py_TYPE(self), state->_fields, &fields) < 0) {
         goto cleanup;
     }
@@ -5148,22 +5148,32 @@ ast_type_init(PyObject *self, PyObject *args, PyObject *kw)
                     goto cleanup;
                 }
             }
-            else if (
-                PyUnicode_CompareWithASCIIString(key, "lineno") != 0 &&
-                PyUnicode_CompareWithASCIIString(key, "col_offset") != 0 &&
-                PyUnicode_CompareWithASCIIString(key, "end_lineno") != 0 &&
-                PyUnicode_CompareWithASCIIString(key, "end_col_offset") != 0
-            ) {
-                if (PyErr_WarnFormat(
-                    PyExc_DeprecationWarning, 1,
-                    "%.400s.__init__ got an unexpected keyword argument '%U'. "
-                    "Support for arbitrary keyword arguments is deprecated "
-                    "and will be removed in Python 3.15.",
-                    Py_TYPE(self)->tp_name, key
-                ) < 0) {
+            else {
+                // Lazily initialize "attributes"
+                if (attributes == NULL) {
+                    attributes = PyObject_GetAttr((PyObject*)Py_TYPE(self), state->_attributes);
+                    if (attributes == NULL) {
+                        res = -1;
+                        goto cleanup;
+                    }
+                }
+                int contains = PySequence_Contains(attributes, key);
+                if (contains == -1) {
                     res = -1;
                     goto cleanup;
                 }
+                else if (contains == 0) {
+                    if (PyErr_WarnFormat(
+                        PyExc_DeprecationWarning, 1,
+                        "%.400s.__init__ got an unexpected keyword argument '%U'. "
+                        "Support for arbitrary keyword arguments is deprecated "
+                        "and will be removed in Python 3.15.",
+                        Py_TYPE(self)->tp_name, key
+                    ) < 0) {
+                        res = -1;
+                        goto cleanup;
+                    }
+                }
             }
             res = PyObject_SetAttr(self, key, value);
             if (res < 0) {
@@ -5246,6 +5256,7 @@ ast_type_init(PyObject *self, PyObject *args, PyObject *kw)
         Py_DECREF(field_types);
     }
   cleanup:
+    Py_XDECREF(attributes);
     Py_XDECREF(fields);
     Py_XDECREF(remaining_fields);
     return res;