]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-121141: add support for `copy.replace` to AST nodes (#121162)
authorBénédikt Tran <10796600+picnixz@users.noreply.github.com>
Thu, 4 Jul 2024 03:10:54 +0000 (05:10 +0200)
committerGitHub <noreply@github.com>
Thu, 4 Jul 2024 03:10:54 +0000 (20:10 -0700)
Doc/whatsnew/3.14.rst
Lib/test/test_ast.py
Misc/NEWS.d/next/Library/2024-06-29-15-21-12.gh-issue-121141.4evD6q.rst [new file with mode: 0644]
Parser/asdl_c.py
Python/Python-ast.c

index 9578ba0c9c965770a2a5f25fde06f7939e246200..d02c10ec9cf3f342fc6b4d7d5eeb1307e887400d 100644 (file)
@@ -89,8 +89,12 @@ Improved Modules
 ast
 ---
 
-Added :func:`ast.compare` for comparing two ASTs.
-(Contributed by Batuhan Taskaya and Jeremy Hylton in :issue:`15987`.)
+* Added :func:`ast.compare` for comparing two ASTs.
+  (Contributed by Batuhan Taskaya and Jeremy Hylton in :issue:`15987`.)
+
+* Add support for :func:`copy.replace` for AST nodes.
+
+  (Contributed by Bénédikt Tran in :gh:`121141`.)
 
 os
 --
index fbd196203111595192dcf4224c0ea7751bc18f6e..eb3aefd5c262f6a1085641f57323b12dfbd20459 100644 (file)
@@ -1149,6 +1149,25 @@ class AST_Tests(unittest.TestCase):
 class CopyTests(unittest.TestCase):
     """Test copying and pickling AST nodes."""
 
+    @staticmethod
+    def iter_ast_classes():
+        """Iterate over the (native) subclasses of ast.AST recursively.
+
+        This excludes the special class ast.Index since its constructor
+        returns an integer.
+        """
+        def do(cls):
+            if cls.__module__ != 'ast':
+                return
+            if cls is ast.Index:
+                return
+
+            yield cls
+            for sub in cls.__subclasses__():
+                yield from do(sub)
+
+        yield from do(ast.AST)
+
     def test_pickling(self):
         import pickle
 
@@ -1218,6 +1237,259 @@ class CopyTests(unittest.TestCase):
                 )):
                     self.assertEqual(to_tuple(child.parent), to_tuple(node))
 
+    def test_replace_interface(self):
+        for klass in self.iter_ast_classes():
+            with self.subTest(klass=klass):
+                self.assertTrue(hasattr(klass, '__replace__'))
+
+            fields = set(klass._fields)
+            with self.subTest(klass=klass, fields=fields):
+                node = klass(**dict.fromkeys(fields))
+                # forbid positional arguments in replace()
+                self.assertRaises(TypeError, copy.replace, node, 1)
+                self.assertRaises(TypeError, node.__replace__, 1)
+
+    def test_replace_native(self):
+        for klass in self.iter_ast_classes():
+            fields = set(klass._fields)
+            attributes = set(klass._attributes)
+
+            with self.subTest(klass=klass, fields=fields, attributes=attributes):
+                # use of object() to ensure that '==' and 'is'
+                # behave similarly in ast.compare(node, repl)
+                old_fields = {field: object() for field in fields}
+                old_attrs = {attr: object() for attr in attributes}
+
+                # check shallow copy
+                node = klass(**old_fields)
+                repl = copy.replace(node)
+                self.assertTrue(ast.compare(node, repl, compare_attributes=True))
+                # check when passing using attributes (they may be optional!)
+                node = klass(**old_fields, **old_attrs)
+                repl = copy.replace(node)
+                self.assertTrue(ast.compare(node, repl, compare_attributes=True))
+
+                for field in fields:
+                    # check when we sometimes have attributes and sometimes not
+                    for init_attrs in [{}, old_attrs]:
+                        node = klass(**old_fields, **init_attrs)
+                        # only change a single field (do not change attributes)
+                        new_value = object()
+                        repl = copy.replace(node, **{field: new_value})
+                        for f in fields:
+                            old_value = old_fields[f]
+                            # assert that there is no side-effect
+                            self.assertIs(getattr(node, f), old_value)
+                            # check the changes
+                            if f != field:
+                                self.assertIs(getattr(repl, f), old_value)
+                            else:
+                                self.assertIs(getattr(repl, f), new_value)
+                        self.assertFalse(ast.compare(node, repl, compare_attributes=True))
+
+                for attribute in attributes:
+                    node = klass(**old_fields, **old_attrs)
+                    # only change a single attribute (do not change fields)
+                    new_attr = object()
+                    repl = copy.replace(node, **{attribute: new_attr})
+                    for a in attributes:
+                        old_attr = old_attrs[a]
+                        # assert that there is no side-effect
+                        self.assertIs(getattr(node, a), old_attr)
+                        # check the changes
+                        if a != attribute:
+                            self.assertIs(getattr(repl, a), old_attr)
+                        else:
+                            self.assertIs(getattr(repl, a), new_attr)
+                    self.assertFalse(ast.compare(node, repl, compare_attributes=True))
+
+    def test_replace_accept_known_class_fields(self):
+        nid, ctx = object(), object()
+
+        node = ast.Name(id=nid, ctx=ctx)
+        self.assertIs(node.id, nid)
+        self.assertIs(node.ctx, ctx)
+
+        new_nid = object()
+        repl = copy.replace(node, id=new_nid)
+        # assert that there is no side-effect
+        self.assertIs(node.id, nid)
+        self.assertIs(node.ctx, ctx)
+        # check the changes
+        self.assertIs(repl.id, new_nid)
+        self.assertIs(repl.ctx, node.ctx)  # no changes
+
+    def test_replace_accept_known_class_attributes(self):
+        node = ast.parse('x').body[0].value
+        self.assertEqual(node.id, 'x')
+        self.assertEqual(node.lineno, 1)
+
+        # constructor allows any type so replace() should do the same
+        lineno = object()
+        repl = copy.replace(node, lineno=lineno)
+        # assert that there is no side-effect
+        self.assertEqual(node.lineno, 1)
+        # check the changes
+        self.assertEqual(repl.id, node.id)
+        self.assertEqual(repl.ctx, node.ctx)
+        self.assertEqual(repl.lineno, lineno)
+
+        _, _, state = node.__reduce__()
+        self.assertEqual(state['id'], 'x')
+        self.assertEqual(state['ctx'], node.ctx)
+        self.assertEqual(state['lineno'], 1)
+
+        _, _, state = repl.__reduce__()
+        self.assertEqual(state['id'], 'x')
+        self.assertEqual(state['ctx'], node.ctx)
+        self.assertEqual(state['lineno'], lineno)
+
+    def test_replace_accept_known_custom_class_fields(self):
+        class MyNode(ast.AST):
+            _fields = ('name', 'data')
+            __annotations__ = {'name': str, 'data': object}
+            __match_args__ = ('name', 'data')
+
+        name, data = 'name', object()
+
+        node = MyNode(name, data)
+        self.assertIs(node.name, name)
+        self.assertIs(node.data, data)
+        # check shallow copy
+        repl = copy.replace(node)
+        # assert that there is no side-effect
+        self.assertIs(node.name, name)
+        self.assertIs(node.data, data)
+        # check the shallow copy
+        self.assertIs(repl.name, name)
+        self.assertIs(repl.data, data)
+
+        node = MyNode(name, data)
+        repl_data = object()
+        # replace custom but known field
+        repl = copy.replace(node, data=repl_data)
+        # assert that there is no side-effect
+        self.assertIs(node.name, name)
+        self.assertIs(node.data, data)
+        # check the changes
+        self.assertIs(repl.name, node.name)
+        self.assertIs(repl.data, repl_data)
+
+    def test_replace_accept_known_custom_class_attributes(self):
+        class MyNode(ast.AST):
+            x = 0
+            y = 1
+            _attributes = ('x', 'y')
+
+        node = MyNode()
+        self.assertEqual(node.x, 0)
+        self.assertEqual(node.y, 1)
+
+        y = object()
+        # custom attributes are currently not supported and raise a warning
+        # because the allowed attributes are hard-coded !
+        msg = (
+            "MyNode.__init__ got an unexpected keyword argument 'y'. "
+            "Support for arbitrary keyword arguments is deprecated and "
+            "will be removed in Python 3.15"
+        )
+        with self.assertWarnsRegex(DeprecationWarning, re.escape(msg)):
+            repl = copy.replace(node, y=y)
+        # assert that there is no side-effect
+        self.assertEqual(node.x, 0)
+        self.assertEqual(node.y, 1)
+        # check the changes
+        self.assertEqual(repl.x, 0)
+        self.assertEqual(repl.y, y)
+
+    def test_replace_ignore_known_custom_instance_fields(self):
+        node = ast.parse('x').body[0].value
+        node.extra = extra = object()  # add instance 'extra' field
+        context = node.ctx
+
+        # assert initial values
+        self.assertIs(node.id, 'x')
+        self.assertIs(node.ctx, context)
+        self.assertIs(node.extra, extra)
+        # shallow copy, but drops extra fields
+        repl = copy.replace(node)
+        # assert that there is no side-effect
+        self.assertIs(node.id, 'x')
+        self.assertIs(node.ctx, context)
+        self.assertIs(node.extra, extra)
+        # verify that the 'extra' field is not kept
+        self.assertIs(repl.id, 'x')
+        self.assertIs(repl.ctx, context)
+        self.assertRaises(AttributeError, getattr, repl, 'extra')
+
+        # change known native field
+        repl = copy.replace(node, id='y')
+        # assert that there is no side-effect
+        self.assertIs(node.id, 'x')
+        self.assertIs(node.ctx, context)
+        self.assertIs(node.extra, extra)
+        # verify that the 'extra' field is not kept
+        self.assertIs(repl.id, 'y')
+        self.assertIs(repl.ctx, context)
+        self.assertRaises(AttributeError, getattr, repl, 'extra')
+
+    def test_replace_reject_missing_field(self):
+        # case: warn if deleted field is not replaced
+        node = ast.parse('x').body[0].value
+        context = node.ctx
+        del node.id
+
+        self.assertRaises(AttributeError, getattr, node, 'id')
+        self.assertIs(node.ctx, context)
+        msg = "Name.__replace__ missing 1 keyword argument: 'id'."
+        with self.assertRaisesRegex(TypeError, re.escape(msg)):
+            copy.replace(node)
+        # assert that there is no side-effect
+        self.assertRaises(AttributeError, getattr, node, 'id')
+        self.assertIs(node.ctx, context)
+
+        # case: do not raise if deleted field is replaced
+        node = ast.parse('x').body[0].value
+        context = node.ctx
+        del node.id
+
+        self.assertRaises(AttributeError, getattr, node, 'id')
+        self.assertIs(node.ctx, context)
+        repl = copy.replace(node, id='y')
+        # assert that there is no side-effect
+        self.assertRaises(AttributeError, getattr, node, 'id')
+        self.assertIs(node.ctx, context)
+        self.assertIs(repl.id, 'y')
+        self.assertIs(repl.ctx, context)
+
+    def test_replace_reject_known_custom_instance_fields_commits(self):
+        node = ast.parse('x').body[0].value
+        node.extra = extra = object()  # add instance 'extra' field
+        context = node.ctx
+
+        # explicit rejection of known instance fields
+        self.assertTrue(hasattr(node, 'extra'))
+        msg = "Name.__replace__ got an unexpected keyword argument 'extra'."
+        with self.assertRaisesRegex(TypeError, re.escape(msg)):
+            copy.replace(node, extra=1)
+        # assert that there is no side-effect
+        self.assertIs(node.id, 'x')
+        self.assertIs(node.ctx, context)
+        self.assertIs(node.extra, extra)
+
+    def test_replace_reject_unknown_instance_fields(self):
+        node = ast.parse('x').body[0].value
+        context = node.ctx
+
+        # explicit rejection of unknown extra fields
+        self.assertRaises(AttributeError, getattr, node, 'unknown')
+        msg = "Name.__replace__ got an unexpected keyword argument 'unknown'."
+        with self.assertRaisesRegex(TypeError, re.escape(msg)):
+            copy.replace(node, unknown=1)
+        # assert that there is no side-effect
+        self.assertIs(node.id, 'x')
+        self.assertIs(node.ctx, context)
+        self.assertRaises(AttributeError, getattr, node, 'unknown')
 
 class ASTHelpers_Test(unittest.TestCase):
     maxDiff = None
diff --git a/Misc/NEWS.d/next/Library/2024-06-29-15-21-12.gh-issue-121141.4evD6q.rst b/Misc/NEWS.d/next/Library/2024-06-29-15-21-12.gh-issue-121141.4evD6q.rst
new file mode 100644 (file)
index 0000000..f2dc621
--- /dev/null
@@ -0,0 +1 @@
+Add support for :func:`copy.replace` to AST nodes. Patch by Bénédikt Tran.
index e338656a5b1eb9647380edf9885f83985c6ba958..f3667801782f2b7cea8a95e1412eac261b005d6d 100755 (executable)
@@ -1132,6 +1132,279 @@ cleanup:
     return result;
 }
 
+/*
+ * Perform the following validations:
+ *
+ *   - All keyword arguments are known 'fields' or 'attributes'.
+ *   - No field or attribute would be left unfilled after copy.replace().
+ *
+ * On success, this returns 1. Otherwise, set a TypeError
+ * exception and returns -1 (no exception is set if some
+ * other internal errors occur).
+ *
+ * Parameters
+ *
+ *      self          The AST node instance.
+ *      dict          The AST node instance dictionary (self.__dict__).
+ *      fields        The list of fields (self._fields).
+ *      attributes    The list of attributes (self._attributes).
+ *      kwargs        Keyword arguments passed to ast_type_replace().
+ *
+ * The 'dict', 'fields', 'attributes' and 'kwargs' arguments can be NULL.
+ *
+ * Note: this function can be removed in 3.15 since the verification
+ *       will be done inside the constructor.
+ */
+static inline int
+ast_type_replace_check(PyObject *self,
+                       PyObject *dict,
+                       PyObject *fields,
+                       PyObject *attributes,
+                       PyObject *kwargs)
+{
+    // While it is possible to make some fast paths that would avoid
+    // allocating objects on the stack, this would cost us readability.
+    // For instance, if 'fields' and 'attributes' are both empty, and
+    // 'kwargs' is not empty, we could raise a TypeError immediately.
+    PyObject *expecting = PySet_New(fields);
+    if (expecting == NULL) {
+        return -1;
+    }
+    if (attributes) {
+        if (_PySet_Update(expecting, attributes) < 0) {
+            Py_DECREF(expecting);
+            return -1;
+        }
+    }
+    // Any keyword argument that is neither a field nor attribute is rejected.
+    // We first need to check whether a keyword argument is accepted or not.
+    // If all keyword arguments are accepted, we compute the required fields
+    // and attributes. A field or attribute is not needed if:
+    //
+    //  1) it is given in 'kwargs', or
+    //  2) it already exists on 'self'.
+    if (kwargs) {
+        Py_ssize_t pos = 0;
+        PyObject *key, *value;
+        while (PyDict_Next(kwargs, &pos, &key, &value)) {
+            int rc = PySet_Discard(expecting, key);
+            if (rc < 0) {
+                Py_DECREF(expecting);
+                return -1;
+            }
+            if (rc == 0) {
+                PyErr_Format(PyExc_TypeError,
+                             "%.400s.__replace__ got an unexpected keyword "
+                             "argument '%U'.", Py_TYPE(self)->tp_name, key);
+                Py_DECREF(expecting);
+                return -1;
+            }
+        }
+    }
+    // check that the remaining fields or attributes would be filled
+    if (dict) {
+        Py_ssize_t pos = 0;
+        PyObject *key, *value;
+        while (PyDict_Next(dict, &pos, &key, &value)) {
+            // Mark fields or attributes that are found on the instance
+            // as non-mandatory. If they are not given in 'kwargs', they
+            // will be shallow-coied; otherwise, they would be replaced
+            // (not in this function).
+            if (PySet_Discard(expecting, key) < 0) {
+                Py_DECREF(expecting);
+                return -1;
+            }
+        }
+        if (attributes) {
+            // Some attributes may or may not be present at runtime.
+            // In particular, now that we checked whether 'kwargs'
+            // is correct or not, we allow any attribute to be missing.
+            //
+            // Note that fields must still be entirely determined when
+            // calling the constructor later.
+            PyObject *unused = PyObject_CallMethodOneArg(expecting,
+                                                         &_Py_ID(difference_update),
+                                                         attributes);
+            if (unused == NULL) {
+                Py_DECREF(expecting);
+                return -1;
+            }
+            Py_DECREF(unused);
+        }
+    }
+    // Now 'expecting' contains the fields or attributes
+    // that would not be filled inside ast_type_replace().
+    Py_ssize_t m = PySet_GET_SIZE(expecting);
+    if (m > 0) {
+        PyObject *names = PyList_New(m);
+        if (names == NULL) {
+            Py_DECREF(expecting);
+            return -1;
+        }
+        Py_ssize_t i = 0, pos = 0;
+        PyObject *item;
+        Py_hash_t hash;
+        while (_PySet_NextEntry(expecting, &pos, &item, &hash)) {
+            PyObject *name = PyObject_Repr(item);
+            if (name == NULL) {
+                Py_DECREF(expecting);
+                Py_DECREF(names);
+                return -1;
+            }
+            // steal the reference 'name'
+            PyList_SET_ITEM(names, i++, name);
+        }
+        Py_DECREF(expecting);
+        if (PyList_Sort(names) < 0) {
+            Py_DECREF(names);
+            return -1;
+        }
+        PyObject *sep = PyUnicode_FromString(", ");
+        if (sep == NULL) {
+            Py_DECREF(names);
+            return -1;
+        }
+        PyObject *str_names = PyUnicode_Join(sep, names);
+        Py_DECREF(sep);
+        Py_DECREF(names);
+        if (str_names == NULL) {
+            return -1;
+        }
+        PyErr_Format(PyExc_TypeError,
+                     "%.400s.__replace__ missing %ld keyword argument%s: %U.",
+                     Py_TYPE(self)->tp_name, m, m == 1 ? "" : "s", str_names);
+        Py_DECREF(str_names);
+        return -1;
+    }
+    else {
+        Py_DECREF(expecting);
+        return 1;
+    }
+}
+
+/*
+ * Python equivalent:
+ *
+ *   for key in keys:
+ *       if hasattr(self, key):
+ *           payload[key] = getattr(self, key)
+ *
+ * The 'keys' argument is a sequence corresponding to
+ * the '_fields' or the '_attributes' of an AST node.
+ *
+ * This returns -1 if an error occurs and 0 otherwise.
+ *
+ * Parameters
+ *
+ *      payload   A dictionary to fill.
+ *      keys      A sequence of keys or NULL for an empty sequence.
+ *      dict      The AST node instance dictionary (must not be NULL).
+ */
+static inline int
+ast_type_replace_update_payload(PyObject *payload,
+                                PyObject *keys,
+                                PyObject *dict)
+{
+    assert(dict != NULL);
+    if (keys == NULL) {
+        return 0;
+    }
+    Py_ssize_t n = PySequence_Size(keys);
+    if (n == -1) {
+        return -1;
+    }
+    for (Py_ssize_t i = 0; i < n; i++) {
+        PyObject *key = PySequence_GetItem(keys, i);
+        if (key == NULL) {
+            return -1;
+        }
+        PyObject *value;
+        if (PyDict_GetItemRef(dict, key, &value) < 0) {
+            Py_DECREF(key);
+            return -1;
+        }
+        if (value == NULL) {
+            Py_DECREF(key);
+            // If a field or attribute is not present at runtime, it should
+            // be explicitly given in 'kwargs'. If not, the constructor will
+            // issue a warning (which becomes an error in 3.15).
+            continue;
+        }
+        int rc = PyDict_SetItem(payload, key, value);
+        Py_DECREF(key);
+        Py_DECREF(value);
+        if (rc < 0) {
+            return -1;
+        }
+    }
+    return 0;
+}
+
+/* copy.replace() support (shallow copy) */
+static PyObject *
+ast_type_replace(PyObject *self, PyObject *args, PyObject *kwargs)
+{
+    if (!_PyArg_NoPositional("__replace__", args)) {
+        return NULL;
+    }
+
+    struct ast_state *state = get_ast_state();
+    if (state == NULL) {
+        return NULL;
+    }
+
+    PyObject *result = NULL;
+    // known AST class fields and attributes
+    PyObject *fields = NULL, *attributes = NULL;
+    // current instance dictionary
+    PyObject *dict = NULL;
+    // constructor positional and keyword arguments
+    PyObject *empty_tuple = NULL, *payload = NULL;
+
+    PyObject *type = (PyObject *)Py_TYPE(self);
+    if (PyObject_GetOptionalAttr(type, state->_fields, &fields) < 0) {
+        goto cleanup;
+    }
+    if (PyObject_GetOptionalAttr(type, state->_attributes, &attributes) < 0) {
+        goto cleanup;
+    }
+    if (PyObject_GetOptionalAttr(self, state->__dict__, &dict) < 0) {
+        goto cleanup;
+    }
+    if (ast_type_replace_check(self, dict, fields, attributes, kwargs) < 0) {
+        goto cleanup;
+    }
+    empty_tuple = PyTuple_New(0);
+    if (empty_tuple == NULL) {
+        goto cleanup;
+    }
+    payload = PyDict_New();
+    if (payload == NULL) {
+        goto cleanup;
+    }
+    if (dict) { // in case __dict__ is missing (for some obscure reason)
+        // copy the instance's fields (possibly NULL)
+        if (ast_type_replace_update_payload(payload, fields, dict) < 0) {
+            goto cleanup;
+        }
+        // copy the instance's attributes (possibly NULL)
+        if (ast_type_replace_update_payload(payload, attributes, dict) < 0) {
+            goto cleanup;
+        }
+    }
+    if (kwargs && PyDict_Update(payload, kwargs) < 0) {
+        goto cleanup;
+    }
+    result = PyObject_Call(type, empty_tuple, payload);
+cleanup:
+    Py_XDECREF(payload);
+    Py_XDECREF(empty_tuple);
+    Py_XDECREF(dict);
+    Py_XDECREF(attributes);
+    Py_XDECREF(fields);
+    return result;
+}
+
 static PyMemberDef ast_type_members[] = {
     {"__dictoffset__", Py_T_PYSSIZET, offsetof(AST_object, dict), Py_READONLY},
     {NULL}  /* Sentinel */
@@ -1139,6 +1412,10 @@ static PyMemberDef ast_type_members[] = {
 
 static PyMethodDef ast_type_methods[] = {
     {"__reduce__", ast_type_reduce, METH_NOARGS, NULL},
+    {"__replace__", _PyCFunction_CAST(ast_type_replace), METH_VARARGS | METH_KEYWORDS,
+     PyDoc_STR("__replace__($self, /, **fields)\\n--\\n\\n"
+               "Return a copy of the AST node with new values "
+               "for the specified fields.")},
     {NULL}
 };
 
@@ -1773,7 +2050,9 @@ def generate_module_def(mod, metadata, f, internal_h):
         #include "pycore_ceval.h"         // _Py_EnterRecursiveCall
         #include "pycore_lock.h"          // _PyOnceFlag
         #include "pycore_interp.h"        // _PyInterpreterState.ast
+        #include "pycore_modsupport.h"    // _PyArg_NoPositional()
         #include "pycore_pystate.h"       // _PyInterpreterState_GET()
+        #include "pycore_setobject.h"     // _PySet_NextEntry(), _PySet_Update()
         #include "pycore_unionobject.h"   // _Py_union_type_or
         #include "structmember.h"
         #include <stddef.h>
index 01ffea1869350b523c52ea5464d832898b2975e0..cca2ee409e797871920315496d9361966a8dd11b 100644 (file)
@@ -6,7 +6,9 @@
 #include "pycore_ceval.h"         // _Py_EnterRecursiveCall
 #include "pycore_lock.h"          // _PyOnceFlag
 #include "pycore_interp.h"        // _PyInterpreterState.ast
+#include "pycore_modsupport.h"    // _PyArg_NoPositional()
 #include "pycore_pystate.h"       // _PyInterpreterState_GET()
+#include "pycore_setobject.h"     // _PySet_NextEntry(), _PySet_Update()
 #include "pycore_unionobject.h"   // _Py_union_type_or
 #include "structmember.h"
 #include <stddef.h>
@@ -5331,6 +5333,279 @@ cleanup:
     return result;
 }
 
+/*
+ * Perform the following validations:
+ *
+ *   - All keyword arguments are known 'fields' or 'attributes'.
+ *   - No field or attribute would be left unfilled after copy.replace().
+ *
+ * On success, this returns 1. Otherwise, set a TypeError
+ * exception and returns -1 (no exception is set if some
+ * other internal errors occur).
+ *
+ * Parameters
+ *
+ *      self          The AST node instance.
+ *      dict          The AST node instance dictionary (self.__dict__).
+ *      fields        The list of fields (self._fields).
+ *      attributes    The list of attributes (self._attributes).
+ *      kwargs        Keyword arguments passed to ast_type_replace().
+ *
+ * The 'dict', 'fields', 'attributes' and 'kwargs' arguments can be NULL.
+ *
+ * Note: this function can be removed in 3.15 since the verification
+ *       will be done inside the constructor.
+ */
+static inline int
+ast_type_replace_check(PyObject *self,
+                       PyObject *dict,
+                       PyObject *fields,
+                       PyObject *attributes,
+                       PyObject *kwargs)
+{
+    // While it is possible to make some fast paths that would avoid
+    // allocating objects on the stack, this would cost us readability.
+    // For instance, if 'fields' and 'attributes' are both empty, and
+    // 'kwargs' is not empty, we could raise a TypeError immediately.
+    PyObject *expecting = PySet_New(fields);
+    if (expecting == NULL) {
+        return -1;
+    }
+    if (attributes) {
+        if (_PySet_Update(expecting, attributes) < 0) {
+            Py_DECREF(expecting);
+            return -1;
+        }
+    }
+    // Any keyword argument that is neither a field nor attribute is rejected.
+    // We first need to check whether a keyword argument is accepted or not.
+    // If all keyword arguments are accepted, we compute the required fields
+    // and attributes. A field or attribute is not needed if:
+    //
+    //  1) it is given in 'kwargs', or
+    //  2) it already exists on 'self'.
+    if (kwargs) {
+        Py_ssize_t pos = 0;
+        PyObject *key, *value;
+        while (PyDict_Next(kwargs, &pos, &key, &value)) {
+            int rc = PySet_Discard(expecting, key);
+            if (rc < 0) {
+                Py_DECREF(expecting);
+                return -1;
+            }
+            if (rc == 0) {
+                PyErr_Format(PyExc_TypeError,
+                             "%.400s.__replace__ got an unexpected keyword "
+                             "argument '%U'.", Py_TYPE(self)->tp_name, key);
+                Py_DECREF(expecting);
+                return -1;
+            }
+        }
+    }
+    // check that the remaining fields or attributes would be filled
+    if (dict) {
+        Py_ssize_t pos = 0;
+        PyObject *key, *value;
+        while (PyDict_Next(dict, &pos, &key, &value)) {
+            // Mark fields or attributes that are found on the instance
+            // as non-mandatory. If they are not given in 'kwargs', they
+            // will be shallow-coied; otherwise, they would be replaced
+            // (not in this function).
+            if (PySet_Discard(expecting, key) < 0) {
+                Py_DECREF(expecting);
+                return -1;
+            }
+        }
+        if (attributes) {
+            // Some attributes may or may not be present at runtime.
+            // In particular, now that we checked whether 'kwargs'
+            // is correct or not, we allow any attribute to be missing.
+            //
+            // Note that fields must still be entirely determined when
+            // calling the constructor later.
+            PyObject *unused = PyObject_CallMethodOneArg(expecting,
+                                                         &_Py_ID(difference_update),
+                                                         attributes);
+            if (unused == NULL) {
+                Py_DECREF(expecting);
+                return -1;
+            }
+            Py_DECREF(unused);
+        }
+    }
+    // Now 'expecting' contains the fields or attributes
+    // that would not be filled inside ast_type_replace().
+    Py_ssize_t m = PySet_GET_SIZE(expecting);
+    if (m > 0) {
+        PyObject *names = PyList_New(m);
+        if (names == NULL) {
+            Py_DECREF(expecting);
+            return -1;
+        }
+        Py_ssize_t i = 0, pos = 0;
+        PyObject *item;
+        Py_hash_t hash;
+        while (_PySet_NextEntry(expecting, &pos, &item, &hash)) {
+            PyObject *name = PyObject_Repr(item);
+            if (name == NULL) {
+                Py_DECREF(expecting);
+                Py_DECREF(names);
+                return -1;
+            }
+            // steal the reference 'name'
+            PyList_SET_ITEM(names, i++, name);
+        }
+        Py_DECREF(expecting);
+        if (PyList_Sort(names) < 0) {
+            Py_DECREF(names);
+            return -1;
+        }
+        PyObject *sep = PyUnicode_FromString(", ");
+        if (sep == NULL) {
+            Py_DECREF(names);
+            return -1;
+        }
+        PyObject *str_names = PyUnicode_Join(sep, names);
+        Py_DECREF(sep);
+        Py_DECREF(names);
+        if (str_names == NULL) {
+            return -1;
+        }
+        PyErr_Format(PyExc_TypeError,
+                     "%.400s.__replace__ missing %ld keyword argument%s: %U.",
+                     Py_TYPE(self)->tp_name, m, m == 1 ? "" : "s", str_names);
+        Py_DECREF(str_names);
+        return -1;
+    }
+    else {
+        Py_DECREF(expecting);
+        return 1;
+    }
+}
+
+/*
+ * Python equivalent:
+ *
+ *   for key in keys:
+ *       if hasattr(self, key):
+ *           payload[key] = getattr(self, key)
+ *
+ * The 'keys' argument is a sequence corresponding to
+ * the '_fields' or the '_attributes' of an AST node.
+ *
+ * This returns -1 if an error occurs and 0 otherwise.
+ *
+ * Parameters
+ *
+ *      payload   A dictionary to fill.
+ *      keys      A sequence of keys or NULL for an empty sequence.
+ *      dict      The AST node instance dictionary (must not be NULL).
+ */
+static inline int
+ast_type_replace_update_payload(PyObject *payload,
+                                PyObject *keys,
+                                PyObject *dict)
+{
+    assert(dict != NULL);
+    if (keys == NULL) {
+        return 0;
+    }
+    Py_ssize_t n = PySequence_Size(keys);
+    if (n == -1) {
+        return -1;
+    }
+    for (Py_ssize_t i = 0; i < n; i++) {
+        PyObject *key = PySequence_GetItem(keys, i);
+        if (key == NULL) {
+            return -1;
+        }
+        PyObject *value;
+        if (PyDict_GetItemRef(dict, key, &value) < 0) {
+            Py_DECREF(key);
+            return -1;
+        }
+        if (value == NULL) {
+            Py_DECREF(key);
+            // If a field or attribute is not present at runtime, it should
+            // be explicitly given in 'kwargs'. If not, the constructor will
+            // issue a warning (which becomes an error in 3.15).
+            continue;
+        }
+        int rc = PyDict_SetItem(payload, key, value);
+        Py_DECREF(key);
+        Py_DECREF(value);
+        if (rc < 0) {
+            return -1;
+        }
+    }
+    return 0;
+}
+
+/* copy.replace() support (shallow copy) */
+static PyObject *
+ast_type_replace(PyObject *self, PyObject *args, PyObject *kwargs)
+{
+    if (!_PyArg_NoPositional("__replace__", args)) {
+        return NULL;
+    }
+
+    struct ast_state *state = get_ast_state();
+    if (state == NULL) {
+        return NULL;
+    }
+
+    PyObject *result = NULL;
+    // known AST class fields and attributes
+    PyObject *fields = NULL, *attributes = NULL;
+    // current instance dictionary
+    PyObject *dict = NULL;
+    // constructor positional and keyword arguments
+    PyObject *empty_tuple = NULL, *payload = NULL;
+
+    PyObject *type = (PyObject *)Py_TYPE(self);
+    if (PyObject_GetOptionalAttr(type, state->_fields, &fields) < 0) {
+        goto cleanup;
+    }
+    if (PyObject_GetOptionalAttr(type, state->_attributes, &attributes) < 0) {
+        goto cleanup;
+    }
+    if (PyObject_GetOptionalAttr(self, state->__dict__, &dict) < 0) {
+        goto cleanup;
+    }
+    if (ast_type_replace_check(self, dict, fields, attributes, kwargs) < 0) {
+        goto cleanup;
+    }
+    empty_tuple = PyTuple_New(0);
+    if (empty_tuple == NULL) {
+        goto cleanup;
+    }
+    payload = PyDict_New();
+    if (payload == NULL) {
+        goto cleanup;
+    }
+    if (dict) { // in case __dict__ is missing (for some obscure reason)
+        // copy the instance's fields (possibly NULL)
+        if (ast_type_replace_update_payload(payload, fields, dict) < 0) {
+            goto cleanup;
+        }
+        // copy the instance's attributes (possibly NULL)
+        if (ast_type_replace_update_payload(payload, attributes, dict) < 0) {
+            goto cleanup;
+        }
+    }
+    if (kwargs && PyDict_Update(payload, kwargs) < 0) {
+        goto cleanup;
+    }
+    result = PyObject_Call(type, empty_tuple, payload);
+cleanup:
+    Py_XDECREF(payload);
+    Py_XDECREF(empty_tuple);
+    Py_XDECREF(dict);
+    Py_XDECREF(attributes);
+    Py_XDECREF(fields);
+    return result;
+}
+
 static PyMemberDef ast_type_members[] = {
     {"__dictoffset__", Py_T_PYSSIZET, offsetof(AST_object, dict), Py_READONLY},
     {NULL}  /* Sentinel */
@@ -5338,6 +5613,10 @@ static PyMemberDef ast_type_members[] = {
 
 static PyMethodDef ast_type_methods[] = {
     {"__reduce__", ast_type_reduce, METH_NOARGS, NULL},
+    {"__replace__", _PyCFunction_CAST(ast_type_replace), METH_VARARGS | METH_KEYWORDS,
+     PyDoc_STR("__replace__($self, /, **fields)\n--\n\n"
+               "Return a copy of the AST node with new values "
+               "for the specified fields.")},
     {NULL}
 };