]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-117266: Fix crashes on user-created AST subclasses (GH-117276)
authorJelle Zijlstra <jelle.zijlstra@gmail.com>
Thu, 28 Mar 2024 10:30:31 +0000 (04:30 -0600)
committerGitHub <noreply@github.com>
Thu, 28 Mar 2024 10:30:31 +0000 (11:30 +0100)
Fix crashes on user-created AST subclasses

Lib/test/test_ast.py
Misc/NEWS.d/next/Core and Builtins/2024-03-26-17-22-38.gh-issue-117266.Kwh79O.rst [new file with mode: 0644]
Parser/asdl_c.py
Python/Python-ast.c

index 7cecf319e3638f0f7443eee04ab888faba359a83..3929e4e00d59c221c12403787613b97f0ea5e09a 100644 (file)
@@ -2916,6 +2916,47 @@ class ASTConstructorTests(unittest.TestCase):
         self.assertEqual(node.name, 'foo')
         self.assertEqual(node.decorator_list, [])
 
+    def test_custom_subclass(self):
+        class NoInit(ast.AST):
+            pass
+
+        obj = NoInit()
+        self.assertIsInstance(obj, NoInit)
+        self.assertEqual(obj.__dict__, {})
+
+        class Fields(ast.AST):
+            _fields = ('a',)
+
+        with self.assertWarnsRegex(DeprecationWarning,
+                                   r"Fields provides _fields but not _field_types."):
+            obj = Fields()
+        with self.assertRaises(AttributeError):
+            obj.a
+        obj = Fields(a=1)
+        self.assertEqual(obj.a, 1)
+
+        class FieldsAndTypes(ast.AST):
+            _fields = ('a',)
+            _field_types = {'a': int | None}
+            a: int | None = None
+
+        obj = FieldsAndTypes()
+        self.assertIs(obj.a, None)
+        obj = FieldsAndTypes(a=1)
+        self.assertEqual(obj.a, 1)
+
+        class FieldsAndTypesNoDefault(ast.AST):
+            _fields = ('a',)
+            _field_types = {'a': int}
+
+        with self.assertWarnsRegex(DeprecationWarning,
+                                   r"FieldsAndTypesNoDefault\.__init__ missing 1 required positional argument: 'a'\."):
+            obj = FieldsAndTypesNoDefault()
+        with self.assertRaises(AttributeError):
+            obj.a
+        obj = FieldsAndTypesNoDefault(a=1)
+        self.assertEqual(obj.a, 1)
+
 
 @support.cpython_only
 class ModuleStateTests(unittest.TestCase):
diff --git a/Misc/NEWS.d/next/Core and Builtins/2024-03-26-17-22-38.gh-issue-117266.Kwh79O.rst b/Misc/NEWS.d/next/Core and Builtins/2024-03-26-17-22-38.gh-issue-117266.Kwh79O.rst
new file mode 100644 (file)
index 0000000..5055954
--- /dev/null
@@ -0,0 +1,2 @@
+Fix crashes for certain user-created subclasses of :class:`ast.AST`. Such
+classes are now expected to set the ``_field_types`` attribute.
index 59cc391881ab8668d29651e41e2be914aaec5240..c4df2c52c032bcec8eadfd17025977860ddd81c6 100755 (executable)
@@ -973,11 +973,22 @@ ast_type_init(PyObject *self, PyObject *args, PyObject *kw)
     Py_ssize_t size = PySet_Size(remaining_fields);
     PyObject *field_types = NULL, *remaining_list = NULL;
     if (size > 0) {
-        if (!PyObject_GetOptionalAttr((PyObject*)Py_TYPE(self), &_Py_ID(_field_types),
-                                      &field_types)) {
+        if (PyObject_GetOptionalAttr((PyObject*)Py_TYPE(self), &_Py_ID(_field_types),
+                                     &field_types) < 0) {
             res = -1;
             goto cleanup;
         }
+        if (field_types == NULL) {
+            if (PyErr_WarnFormat(
+                PyExc_DeprecationWarning, 1,
+                "%.400s provides _fields but not _field_types. "
+                "This will become an error in Python 3.15.",
+                Py_TYPE(self)->tp_name
+            ) < 0) {
+                res = -1;
+            }
+            goto cleanup;
+        }
         remaining_list = PySequence_List(remaining_fields);
         if (!remaining_list) {
             goto set_remaining_cleanup;
index 7b591ddaa298695c69113cb817c40c307c8dc571..60b46263a0d329978d3a00fc2a954a7ea3421dfe 100644 (file)
@@ -5119,11 +5119,22 @@ ast_type_init(PyObject *self, PyObject *args, PyObject *kw)
     Py_ssize_t size = PySet_Size(remaining_fields);
     PyObject *field_types = NULL, *remaining_list = NULL;
     if (size > 0) {
-        if (!PyObject_GetOptionalAttr((PyObject*)Py_TYPE(self), &_Py_ID(_field_types),
-                                      &field_types)) {
+        if (PyObject_GetOptionalAttr((PyObject*)Py_TYPE(self), &_Py_ID(_field_types),
+                                     &field_types) < 0) {
             res = -1;
             goto cleanup;
         }
+        if (field_types == NULL) {
+            if (PyErr_WarnFormat(
+                PyExc_DeprecationWarning, 1,
+                "%.400s provides _fields but not _field_types. "
+                "This will become an error in Python 3.15.",
+                Py_TYPE(self)->tp_name
+            ) < 0) {
+                res = -1;
+            }
+            goto cleanup;
+        }
         remaining_list = PySequence_List(remaining_fields);
         if (!remaining_list) {
             goto set_remaining_cleanup;