]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-108751: Add copy.replace() function (GH-108752)
authorSerhiy Storchaka <storchaka@gmail.com>
Wed, 6 Sep 2023 20:55:42 +0000 (23:55 +0300)
committerGitHub <noreply@github.com>
Wed, 6 Sep 2023 20:55:42 +0000 (23:55 +0300)
It creates a modified copy of an object by calling the object's
__replace__() method.

It is a generalization of dataclasses.replace(), named tuple's _replace()
method and replace() methods in various classes, and supports all these
stdlib classes.

19 files changed:
Doc/library/collections.rst
Doc/library/copy.rst
Doc/library/dataclasses.rst
Doc/library/datetime.rst
Doc/library/inspect.rst
Doc/library/types.rst
Doc/whatsnew/3.13.rst
Lib/_pydatetime.py
Lib/collections/__init__.py
Lib/copy.py
Lib/dataclasses.py
Lib/inspect.py
Lib/test/datetimetester.py
Lib/test/test_code.py
Lib/test/test_copy.py
Lib/test/test_inspect.py
Misc/NEWS.d/next/Library/2023-09-01-13-14-08.gh-issue-108751.2itqwe.rst [new file with mode: 0644]
Modules/_datetimemodule.c
Objects/codeobject.c

index b8b231bb15b1b0fade5198b59cad81a332e1da5a..03cb1dca8f816c53c6fbadd70d8969784615418b 100644 (file)
@@ -979,6 +979,8 @@ field names, the method and attribute names start with an underscore.
         >>> for partnum, record in inventory.items():
         ...     inventory[partnum] = record._replace(price=newprices[partnum], timestamp=time.now())
 
+    Named tuples are also supported by generic function :func:`copy.replace`.
+
 .. attribute:: somenamedtuple._fields
 
     Tuple of strings listing the field names.  Useful for introspection
index 8f32477ed508c372ba57e3b73d255df80206c4ae..cc4ca034d07a00b7b88c3430ebe1eaa4b83e66aa 100644 (file)
@@ -17,14 +17,22 @@ operations (explained below).
 
 Interface summary:
 
-.. function:: copy(x)
+.. function:: copy(obj)
 
-   Return a shallow copy of *x*.
+   Return a shallow copy of *obj*.
 
 
-.. function:: deepcopy(x[, memo])
+.. function:: deepcopy(obj[, memo])
 
-   Return a deep copy of *x*.
+   Return a deep copy of *obj*.
+
+
+.. function:: replace(obj, /, **changes)
+
+   Creates a new object of the same type as *obj*, replacing fields with values
+   from *changes*.
+
+   .. versionadded:: 3.13
 
 
 .. exception:: Error
@@ -89,6 +97,20 @@ with the component as first argument and the memo dictionary as second argument.
 The memo dictionary should be treated as an opaque object.
 
 
+.. index::
+   single: __replace__() (replace protocol)
+
+Function :func:`replace` is more limited than :func:`copy` and :func:`deepcopy`,
+and only supports named tuples created by :func:`~collections.namedtuple`,
+:mod:`dataclasses`, and other classes which define method :meth:`!__replace__`.
+
+   .. method:: __replace__(self, /, **changes)
+      :noindex:
+
+:meth:`!__replace__` should create a new object of the same type,
+replacing fields with values from *changes*.
+
+
 .. seealso::
 
    Module :mod:`pickle`
index d68748767c5e377d1090d4524ccea11073529238..d78a6071a50e4bef0c3668b4919bd722727ed56b 100644 (file)
@@ -456,6 +456,8 @@ Module contents
    ``replace()`` (or similarly named) method which handles instance
    copying.
 
+   Dataclass instances are also supported by generic function :func:`copy.replace`.
+
 .. function:: is_dataclass(obj)
 
    Return ``True`` if its parameter is a dataclass or an instance of one,
index 04cc75562937e0ee49e3ba5263faec2cbae4c628..0b9d42f32e3bd603b2f331c98b6249327a9dd3bd 100644 (file)
@@ -652,6 +652,9 @@ Instance methods:
        >>> d.replace(day=26)
        datetime.date(2002, 12, 26)
 
+   :class:`date` objects are also supported by generic function
+   :func:`copy.replace`.
+
 
 .. method:: date.timetuple()
 
@@ -1251,6 +1254,9 @@ Instance methods:
    ``tzinfo=None`` can be specified to create a naive datetime from an aware
    datetime with no conversion of date and time data.
 
+   :class:`datetime` objects are also supported by generic function
+   :func:`copy.replace`.
+
    .. versionadded:: 3.6
       Added the ``fold`` argument.
 
@@ -1827,6 +1833,9 @@ Instance methods:
    ``tzinfo=None`` can be specified to create a naive :class:`.time` from an
    aware :class:`.time`, without conversion of the time data.
 
+   :class:`time` objects are also supported by generic function
+   :func:`copy.replace`.
+
    .. versionadded:: 3.6
       Added the ``fold`` argument.
 
index 603ac3263bb9ec043e8c3b1270a42790e49a4855..fe0ed135029f0f9649a093d1b95a1dbbe36a2b9e 100644 (file)
@@ -689,8 +689,8 @@ function.
    The optional *return_annotation* argument, can be an arbitrary Python object,
    is the "return" annotation of the callable.
 
-   Signature objects are *immutable*.  Use :meth:`Signature.replace` to make a
-   modified copy.
+   Signature objects are *immutable*.  Use :meth:`Signature.replace` or
+   :func:`copy.replace` to make a modified copy.
 
    .. versionchanged:: 3.5
       Signature objects are picklable and :term:`hashable`.
@@ -746,6 +746,9 @@ function.
          >>> str(new_sig)
          "(a, b) -> 'new return anno'"
 
+      Signature objects are also supported by generic function
+      :func:`copy.replace`.
+
    .. classmethod:: Signature.from_callable(obj, *, follow_wrapped=True, globalns=None, localns=None)
 
        Return a :class:`Signature` (or its subclass) object for a given callable
@@ -769,7 +772,7 @@ function.
 .. class:: Parameter(name, kind, *, default=Parameter.empty, annotation=Parameter.empty)
 
    Parameter objects are *immutable*.  Instead of modifying a Parameter object,
-   you can use :meth:`Parameter.replace` to create a modified copy.
+   you can use :meth:`Parameter.replace` or :func:`copy.replace` to create a modified copy.
 
    .. versionchanged:: 3.5
       Parameter objects are picklable and :term:`hashable`.
@@ -892,6 +895,8 @@ function.
          >>> str(param.replace(default=Parameter.empty, annotation='spam'))
          "foo:'spam'"
 
+      Parameter objects are also supported by generic function :func:`copy.replace`.
+
    .. versionchanged:: 3.4
       In Python 3.3 Parameter objects were allowed to have ``name`` set
       to ``None`` if their ``kind`` was set to ``POSITIONAL_ONLY``.
index 8cbe17df16f1078a1abab636913cf188bf7fccf5..82300afef0641e60cfb78a8b74b112c28829fc9c 100644 (file)
@@ -200,6 +200,8 @@ Standard names are defined for the following types:
 
      Return a copy of the code object with new values for the specified fields.
 
+     Code objects are also supported by generic function :func:`copy.replace`.
+
      .. versionadded:: 3.8
 
 .. data:: CellType
index de23172ac7a43b1510ea9872fa811f18b5982e90..8c6467562aeb62249623816d124356699eb77836 100644 (file)
@@ -115,6 +115,18 @@ array
   It can be used instead of ``'u'`` type code, which is deprecated.
   (Contributed by Inada Naoki in :gh:`80480`.)
 
+copy
+----
+
+* Add :func:`copy.replace` function which allows to create a modified copy of
+  an object, which is especially usefule for immutable objects.
+  It supports named tuples created with the factory function
+  :func:`collections.namedtuple`, :class:`~dataclasses.dataclass` instances,
+  various :mod:`datetime` objects, :class:`~inspect.Signature` objects,
+  :class:`~inspect.Parameter` objects, :ref:`code object <code-objects>`, and
+  any user classes which define the :meth:`!__replace__` method.
+  (Contributed by Serhiy Storchaka in :gh:`108751`.)
+
 dbm
 ---
 
index 549fcda19dccf249ea7443ef937fa26ae3374d7f..df616bbaf8388d27e5877dbb885063831a0502a6 100644 (file)
@@ -1112,6 +1112,8 @@ class date:
             day = self._day
         return type(self)(year, month, day)
 
+    __replace__ = replace
+
     # Comparisons of date objects with other.
 
     def __eq__(self, other):
@@ -1637,6 +1639,8 @@ class time:
             fold = self._fold
         return type(self)(hour, minute, second, microsecond, tzinfo, fold=fold)
 
+    __replace__ = replace
+
     # Pickle support.
 
     def _getstate(self, protocol=3):
@@ -1983,6 +1987,8 @@ class datetime(date):
         return type(self)(year, month, day, hour, minute, second,
                           microsecond, tzinfo, fold=fold)
 
+    __replace__ = replace
+
     def _local_timezone(self):
         if self.tzinfo is None:
             ts = self._mktime()
index 8652dc8a4ec4501ef6f2f5eb111c3aa4d541cf8d..a461550ea40da74da9c112d77b8a83f5f45bcbef 100644 (file)
@@ -495,6 +495,7 @@ def namedtuple(typename, field_names, *, rename=False, defaults=None, module=Non
         '_field_defaults': field_defaults,
         '__new__': __new__,
         '_make': _make,
+        '__replace__': _replace,
         '_replace': _replace,
         '__repr__': __repr__,
         '_asdict': _asdict,
index da2908ef623d8c9cfc7a10fbf491d69c914b116e..6d7bb9a111b5b48f41c475e4d9718472995b835b 100644 (file)
@@ -290,3 +290,16 @@ def _reconstruct(x, memo, func, args,
     return y
 
 del types, weakref
+
+
+def replace(obj, /, **changes):
+    """Return a new object replacing specified fields with new values.
+
+    This is especially useful for immutable objects, like named tuples or
+    frozen dataclasses.
+    """
+    cls = obj.__class__
+    func = getattr(cls, '__replace__', None)
+    if func is None:
+        raise TypeError(f"replace() does not support {cls.__name__} objects")
+    return func(obj, **changes)
index 21f3fa5c213f1f6a38adf2dc4edc4821bdfce975..84f8d68ce092a4502c54db77058fc8efd19643aa 100644 (file)
@@ -1073,6 +1073,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
                                     globals,
                                     slots,
                           ))
+    _set_new_attribute(cls, '__replace__', _replace)
 
     # Get the fields as a list, and include only real fields.  This is
     # used in all of the following methods.
@@ -1546,13 +1547,15 @@ def replace(obj, /, **changes):
       c1 = replace(c, x=3)
       assert c1.x == 3 and c1.y == 2
     """
+    if not _is_dataclass_instance(obj):
+        raise TypeError("replace() should be called on dataclass instances")
+    return _replace(obj, **changes)
+
 
+def _replace(obj, /, **changes):
     # We're going to mutate 'changes', but that's okay because it's a
     # new dict, even if called with 'replace(obj, **my_changes)'.
 
-    if not _is_dataclass_instance(obj):
-        raise TypeError("replace() should be called on dataclass instances")
-
     # It's an error to have init=False fields in 'changes'.
     # If a field is not in 'changes', read its value from the provided obj.
 
index c8211833dd0831962f1f69e92b6b9efdc1a63d1f..aaa22bef8966028fb8d01bf1c576eec188fa6a58 100644 (file)
@@ -2870,6 +2870,8 @@ class Parameter:
 
         return formatted
 
+    __replace__ = replace
+
     def __repr__(self):
         return '<{} "{}">'.format(self.__class__.__name__, self)
 
@@ -3130,6 +3132,8 @@ class Signature:
         return type(self)(parameters,
                           return_annotation=return_annotation)
 
+    __replace__ = replace
+
     def _hash_basis(self):
         params = tuple(param for param in self.parameters.values()
                              if param.kind != _KEYWORD_ONLY)
index 55e061950ff2806bb7038abddb5f773b12cd45ae..8bda17358db87f93831288d3a01087ff630b974b 100644 (file)
@@ -1699,22 +1699,23 @@ class TestDate(HarmlessMixedComparison, unittest.TestCase):
         cls = self.theclass
         args = [1, 2, 3]
         base = cls(*args)
-        self.assertEqual(base, base.replace())
+        self.assertEqual(base.replace(), base)
+        self.assertEqual(copy.replace(base), base)
 
-        i = 0
-        for name, newval in (("year", 2),
-                             ("month", 3),
-                             ("day", 4)):
+        changes = (("year", 2),
+                   ("month", 3),
+                   ("day", 4))
+        for i, (name, newval) in enumerate(changes):
             newargs = args[:]
             newargs[i] = newval
             expected = cls(*newargs)
-            got = base.replace(**{name: newval})
-            self.assertEqual(expected, got)
-            i += 1
+            self.assertEqual(base.replace(**{name: newval}), expected)
+            self.assertEqual(copy.replace(base, **{name: newval}), expected)
 
         # Out of bounds.
         base = cls(2000, 2, 29)
         self.assertRaises(ValueError, base.replace, year=2001)
+        self.assertRaises(ValueError, copy.replace, base, year=2001)
 
     def test_subclass_replace(self):
         class DateSubclass(self.theclass):
@@ -1722,6 +1723,7 @@ class TestDate(HarmlessMixedComparison, unittest.TestCase):
 
         dt = DateSubclass(2012, 1, 1)
         self.assertIs(type(dt.replace(year=2013)), DateSubclass)
+        self.assertIs(type(copy.replace(dt, year=2013)), DateSubclass)
 
     def test_subclass_date(self):
 
@@ -2856,26 +2858,27 @@ class TestDateTime(TestDate):
         cls = self.theclass
         args = [1, 2, 3, 4, 5, 6, 7]
         base = cls(*args)
-        self.assertEqual(base, base.replace())
-
-        i = 0
-        for name, newval in (("year", 2),
-                             ("month", 3),
-                             ("day", 4),
-                             ("hour", 5),
-                             ("minute", 6),
-                             ("second", 7),
-                             ("microsecond", 8)):
+        self.assertEqual(base.replace(), base)
+        self.assertEqual(copy.replace(base), base)
+
+        changes = (("year", 2),
+                   ("month", 3),
+                   ("day", 4),
+                   ("hour", 5),
+                   ("minute", 6),
+                   ("second", 7),
+                   ("microsecond", 8))
+        for i, (name, newval) in enumerate(changes):
             newargs = args[:]
             newargs[i] = newval
             expected = cls(*newargs)
-            got = base.replace(**{name: newval})
-            self.assertEqual(expected, got)
-            i += 1
+            self.assertEqual(base.replace(**{name: newval}), expected)
+            self.assertEqual(copy.replace(base, **{name: newval}), expected)
 
         # Out of bounds.
         base = cls(2000, 2, 29)
         self.assertRaises(ValueError, base.replace, year=2001)
+        self.assertRaises(ValueError, copy.replace, base, year=2001)
 
     @support.run_with_tz('EDT4')
     def test_astimezone(self):
@@ -3671,19 +3674,19 @@ class TestTime(HarmlessMixedComparison, unittest.TestCase):
         cls = self.theclass
         args = [1, 2, 3, 4]
         base = cls(*args)
-        self.assertEqual(base, base.replace())
-
-        i = 0
-        for name, newval in (("hour", 5),
-                             ("minute", 6),
-                             ("second", 7),
-                             ("microsecond", 8)):
+        self.assertEqual(base.replace(), base)
+        self.assertEqual(copy.replace(base), base)
+
+        changes = (("hour", 5),
+                   ("minute", 6),
+                   ("second", 7),
+                   ("microsecond", 8))
+        for i, (name, newval) in enumerate(changes):
             newargs = args[:]
             newargs[i] = newval
             expected = cls(*newargs)
-            got = base.replace(**{name: newval})
-            self.assertEqual(expected, got)
-            i += 1
+            self.assertEqual(base.replace(**{name: newval}), expected)
+            self.assertEqual(copy.replace(base, **{name: newval}), expected)
 
         # Out of bounds.
         base = cls(1)
@@ -3691,6 +3694,10 @@ class TestTime(HarmlessMixedComparison, unittest.TestCase):
         self.assertRaises(ValueError, base.replace, minute=-1)
         self.assertRaises(ValueError, base.replace, second=100)
         self.assertRaises(ValueError, base.replace, microsecond=1000000)
+        self.assertRaises(ValueError, copy.replace, base, hour=24)
+        self.assertRaises(ValueError, copy.replace, base, minute=-1)
+        self.assertRaises(ValueError, copy.replace, base, second=100)
+        self.assertRaises(ValueError, copy.replace, base, microsecond=1000000)
 
     def test_subclass_replace(self):
         class TimeSubclass(self.theclass):
@@ -3698,6 +3705,7 @@ class TestTime(HarmlessMixedComparison, unittest.TestCase):
 
         ctime = TimeSubclass(12, 30)
         self.assertIs(type(ctime.replace(hour=10)), TimeSubclass)
+        self.assertIs(type(copy.replace(ctime, hour=10)), TimeSubclass)
 
     def test_subclass_time(self):
 
@@ -4085,31 +4093,37 @@ class TestTimeTZ(TestTime, TZInfoBase, unittest.TestCase):
         zm200 = FixedOffset(timedelta(minutes=-200), "-200")
         args = [1, 2, 3, 4, z100]
         base = cls(*args)
-        self.assertEqual(base, base.replace())
-
-        i = 0
-        for name, newval in (("hour", 5),
-                             ("minute", 6),
-                             ("second", 7),
-                             ("microsecond", 8),
-                             ("tzinfo", zm200)):
+        self.assertEqual(base.replace(), base)
+        self.assertEqual(copy.replace(base), base)
+
+        changes = (("hour", 5),
+                   ("minute", 6),
+                   ("second", 7),
+                   ("microsecond", 8),
+                   ("tzinfo", zm200))
+        for i, (name, newval) in enumerate(changes):
             newargs = args[:]
             newargs[i] = newval
             expected = cls(*newargs)
-            got = base.replace(**{name: newval})
-            self.assertEqual(expected, got)
-            i += 1
+            self.assertEqual(base.replace(**{name: newval}), expected)
+            self.assertEqual(copy.replace(base, **{name: newval}), expected)
 
         # Ensure we can get rid of a tzinfo.
         self.assertEqual(base.tzname(), "+100")
         base2 = base.replace(tzinfo=None)
         self.assertIsNone(base2.tzinfo)
         self.assertIsNone(base2.tzname())
+        base22 = copy.replace(base, tzinfo=None)
+        self.assertIsNone(base22.tzinfo)
+        self.assertIsNone(base22.tzname())
 
         # Ensure we can add one.
         base3 = base2.replace(tzinfo=z100)
         self.assertEqual(base, base3)
         self.assertIs(base.tzinfo, base3.tzinfo)
+        base32 = copy.replace(base22, tzinfo=z100)
+        self.assertEqual(base, base32)
+        self.assertIs(base.tzinfo, base32.tzinfo)
 
         # Out of bounds.
         base = cls(1)
@@ -4117,6 +4131,10 @@ class TestTimeTZ(TestTime, TZInfoBase, unittest.TestCase):
         self.assertRaises(ValueError, base.replace, minute=-1)
         self.assertRaises(ValueError, base.replace, second=100)
         self.assertRaises(ValueError, base.replace, microsecond=1000000)
+        self.assertRaises(ValueError, copy.replace, base, hour=24)
+        self.assertRaises(ValueError, copy.replace, base, minute=-1)
+        self.assertRaises(ValueError, copy.replace, base, second=100)
+        self.assertRaises(ValueError, copy.replace, base, microsecond=1000000)
 
     def test_mixed_compare(self):
         t1 = self.theclass(1, 2, 3)
@@ -4885,38 +4903,45 @@ class TestDateTimeTZ(TestDateTime, TZInfoBase, unittest.TestCase):
         zm200 = FixedOffset(timedelta(minutes=-200), "-200")
         args = [1, 2, 3, 4, 5, 6, 7, z100]
         base = cls(*args)
-        self.assertEqual(base, base.replace())
-
-        i = 0
-        for name, newval in (("year", 2),
-                             ("month", 3),
-                             ("day", 4),
-                             ("hour", 5),
-                             ("minute", 6),
-                             ("second", 7),
-                             ("microsecond", 8),
-                             ("tzinfo", zm200)):
+        self.assertEqual(base.replace(), base)
+        self.assertEqual(copy.replace(base), base)
+
+        changes = (("year", 2),
+                   ("month", 3),
+                   ("day", 4),
+                   ("hour", 5),
+                   ("minute", 6),
+                   ("second", 7),
+                   ("microsecond", 8),
+                   ("tzinfo", zm200))
+        for i, (name, newval) in enumerate(changes):
             newargs = args[:]
             newargs[i] = newval
             expected = cls(*newargs)
-            got = base.replace(**{name: newval})
-            self.assertEqual(expected, got)
-            i += 1
+            self.assertEqual(base.replace(**{name: newval}), expected)
+            self.assertEqual(copy.replace(base, **{name: newval}), expected)
 
         # Ensure we can get rid of a tzinfo.
         self.assertEqual(base.tzname(), "+100")
         base2 = base.replace(tzinfo=None)
         self.assertIsNone(base2.tzinfo)
         self.assertIsNone(base2.tzname())
+        base22 = copy.replace(base, tzinfo=None)
+        self.assertIsNone(base22.tzinfo)
+        self.assertIsNone(base22.tzname())
 
         # Ensure we can add one.
         base3 = base2.replace(tzinfo=z100)
         self.assertEqual(base, base3)
         self.assertIs(base.tzinfo, base3.tzinfo)
+        base32 = copy.replace(base22, tzinfo=z100)
+        self.assertEqual(base, base32)
+        self.assertIs(base.tzinfo, base32.tzinfo)
 
         # Out of bounds.
         base = cls(2000, 2, 29)
         self.assertRaises(ValueError, base.replace, year=2001)
+        self.assertRaises(ValueError, copy.replace, base, year=2001)
 
     def test_more_astimezone(self):
         # The inherited test_astimezone covered some trivial and error cases.
index e056c16466e8c47aeb7f89ec1b7c05e1a790dd3d..812c0682569207645726d8f7f9b25dbfd6c3e44b 100644 (file)
@@ -125,6 +125,7 @@ consts: ('None',)
 
 """
 
+import copy
 import inspect
 import sys
 import threading
@@ -280,11 +281,17 @@ class CodeTest(unittest.TestCase):
             with self.subTest(attr=attr, value=value):
                 new_code = code.replace(**{attr: value})
                 self.assertEqual(getattr(new_code, attr), value)
+                new_code = copy.replace(code, **{attr: value})
+                self.assertEqual(getattr(new_code, attr), value)
 
         new_code = code.replace(co_varnames=code2.co_varnames,
                                 co_nlocals=code2.co_nlocals)
         self.assertEqual(new_code.co_varnames, code2.co_varnames)
         self.assertEqual(new_code.co_nlocals, code2.co_nlocals)
+        new_code = copy.replace(code, co_varnames=code2.co_varnames,
+                                co_nlocals=code2.co_nlocals)
+        self.assertEqual(new_code.co_varnames, code2.co_varnames)
+        self.assertEqual(new_code.co_nlocals, code2.co_nlocals)
 
     def test_nlocals_mismatch(self):
         def func():
index 826e46824e004ccafd698d13bf6d08b0c2cd0fda..c66c6eeb00811e65f781e9c8081136b9eb4fa177 100644 (file)
@@ -4,7 +4,7 @@ import copy
 import copyreg
 import weakref
 import abc
-from operator import le, lt, ge, gt, eq, ne
+from operator import le, lt, ge, gt, eq, ne, attrgetter
 
 import unittest
 from test import support
@@ -899,7 +899,71 @@ class TestCopy(unittest.TestCase):
         g.b()
 
 
+class TestReplace(unittest.TestCase):
+
+    def test_unsupported(self):
+        self.assertRaises(TypeError, copy.replace, 1)
+        self.assertRaises(TypeError, copy.replace, [])
+        self.assertRaises(TypeError, copy.replace, {})
+        def f(): pass
+        self.assertRaises(TypeError, copy.replace, f)
+        class A: pass
+        self.assertRaises(TypeError, copy.replace, A)
+        self.assertRaises(TypeError, copy.replace, A())
+
+    def test_replace_method(self):
+        class A:
+            def __new__(cls, x, y=0):
+                self = object.__new__(cls)
+                self.x = x
+                self.y = y
+                return self
+
+            def __init__(self, *args, **kwargs):
+                self.z = self.x + self.y
+
+            def __replace__(self, **changes):
+                x = changes.get('x', self.x)
+                y = changes.get('y', self.y)
+                return type(self)(x, y)
+
+        attrs = attrgetter('x', 'y', 'z')
+        a = A(11, 22)
+        self.assertEqual(attrs(copy.replace(a)), (11, 22, 33))
+        self.assertEqual(attrs(copy.replace(a, x=1)), (1, 22, 23))
+        self.assertEqual(attrs(copy.replace(a, y=2)), (11, 2, 13))
+        self.assertEqual(attrs(copy.replace(a, x=1, y=2)), (1, 2, 3))
+
+    def test_namedtuple(self):
+        from collections import namedtuple
+        Point = namedtuple('Point', 'x y', defaults=(0,))
+        p = Point(11, 22)
+        self.assertEqual(copy.replace(p), (11, 22))
+        self.assertEqual(copy.replace(p, x=1), (1, 22))
+        self.assertEqual(copy.replace(p, y=2), (11, 2))
+        self.assertEqual(copy.replace(p, x=1, y=2), (1, 2))
+        with self.assertRaisesRegex(ValueError, 'unexpected field name'):
+            copy.replace(p, x=1, error=2)
+
+    def test_dataclass(self):
+        from dataclasses import dataclass
+        @dataclass
+        class C:
+            x: int
+            y: int = 0
+
+        attrs = attrgetter('x', 'y')
+        c = C(11, 22)
+        self.assertEqual(attrs(copy.replace(c)), (11, 22))
+        self.assertEqual(attrs(copy.replace(c, x=1)), (1, 22))
+        self.assertEqual(attrs(copy.replace(c, y=2)), (11, 2))
+        self.assertEqual(attrs(copy.replace(c, x=1, y=2)), (1, 2))
+        with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
+            copy.replace(c, x=1, error=2)
+
+
 def global_foo(x, y): return x+y
 
+
 if __name__ == "__main__":
     unittest.main()
index 78ef817906b2aaf49d1efdb10c4e7e106c58f7ed..2fb356a0529ab06179668f5da2d9158b84651a6e 100644 (file)
@@ -1,6 +1,7 @@
 import asyncio
 import builtins
 import collections
+import copy
 import datetime
 import functools
 import importlib
@@ -3830,6 +3831,28 @@ class TestSignatureObject(unittest.TestCase):
                                 P('bar', P.VAR_POSITIONAL)])),
                          '(foo, /, *bar)')
 
+    def test_signature_replace_parameters(self):
+        def test(a, b) -> 42:
+            pass
+
+        sig = inspect.signature(test)
+        parameters = sig.parameters
+        sig = sig.replace(parameters=list(parameters.values())[1:])
+        self.assertEqual(list(sig.parameters), ['b'])
+        self.assertEqual(sig.parameters['b'], parameters['b'])
+        self.assertEqual(sig.return_annotation, 42)
+        sig = sig.replace(parameters=())
+        self.assertEqual(dict(sig.parameters), {})
+
+        sig = inspect.signature(test)
+        parameters = sig.parameters
+        sig = copy.replace(sig, parameters=list(parameters.values())[1:])
+        self.assertEqual(list(sig.parameters), ['b'])
+        self.assertEqual(sig.parameters['b'], parameters['b'])
+        self.assertEqual(sig.return_annotation, 42)
+        sig = copy.replace(sig, parameters=())
+        self.assertEqual(dict(sig.parameters), {})
+
     def test_signature_replace_anno(self):
         def test() -> 42:
             pass
@@ -3843,6 +3866,15 @@ class TestSignatureObject(unittest.TestCase):
         self.assertEqual(sig.return_annotation, 42)
         self.assertEqual(sig, inspect.signature(test))
 
+        sig = inspect.signature(test)
+        sig = copy.replace(sig, return_annotation=None)
+        self.assertIs(sig.return_annotation, None)
+        sig = copy.replace(sig, return_annotation=sig.empty)
+        self.assertIs(sig.return_annotation, sig.empty)
+        sig = copy.replace(sig, return_annotation=42)
+        self.assertEqual(sig.return_annotation, 42)
+        self.assertEqual(sig, inspect.signature(test))
+
     def test_signature_replaced(self):
         def test():
             pass
@@ -4187,41 +4219,66 @@ class TestParameterObject(unittest.TestCase):
         p = inspect.Parameter('foo', default=42,
                               kind=inspect.Parameter.KEYWORD_ONLY)
 
-        self.assertIsNot(p, p.replace())
-        self.assertEqual(p, p.replace())
+        self.assertIsNot(p.replace(), p)
+        self.assertEqual(p.replace(), p)
+        self.assertIsNot(copy.replace(p), p)
+        self.assertEqual(copy.replace(p), p)
 
         p2 = p.replace(annotation=1)
         self.assertEqual(p2.annotation, 1)
         p2 = p2.replace(annotation=p2.empty)
-        self.assertEqual(p, p2)
+        self.assertEqual(p2, p)
+        p3 = copy.replace(p, annotation=1)
+        self.assertEqual(p3.annotation, 1)
+        p3 = copy.replace(p3, annotation=p3.empty)
+        self.assertEqual(p3, p)
 
         p2 = p2.replace(name='bar')
         self.assertEqual(p2.name, 'bar')
         self.assertNotEqual(p2, p)
+        p3 = copy.replace(p3, name='bar')
+        self.assertEqual(p3.name, 'bar')
+        self.assertNotEqual(p3, p)
 
         with self.assertRaisesRegex(ValueError,
                                     'name is a required attribute'):
             p2 = p2.replace(name=p2.empty)
+        with self.assertRaisesRegex(ValueError,
+                                    'name is a required attribute'):
+            p3 = copy.replace(p3, name=p3.empty)
 
         p2 = p2.replace(name='foo', default=None)
         self.assertIs(p2.default, None)
         self.assertNotEqual(p2, p)
+        p3 = copy.replace(p3, name='foo', default=None)
+        self.assertIs(p3.default, None)
+        self.assertNotEqual(p3, p)
 
         p2 = p2.replace(name='foo', default=p2.empty)
         self.assertIs(p2.default, p2.empty)
-
+        p3 = copy.replace(p3, name='foo', default=p3.empty)
+        self.assertIs(p3.default, p3.empty)
 
         p2 = p2.replace(default=42, kind=p2.POSITIONAL_OR_KEYWORD)
         self.assertEqual(p2.kind, p2.POSITIONAL_OR_KEYWORD)
         self.assertNotEqual(p2, p)
+        p3 = copy.replace(p3, default=42, kind=p3.POSITIONAL_OR_KEYWORD)
+        self.assertEqual(p3.kind, p3.POSITIONAL_OR_KEYWORD)
+        self.assertNotEqual(p3, p)
 
         with self.assertRaisesRegex(ValueError,
                                     "value <class 'inspect._empty'> "
                                     "is not a valid Parameter.kind"):
             p2 = p2.replace(kind=p2.empty)
+        with self.assertRaisesRegex(ValueError,
+                                    "value <class 'inspect._empty'> "
+                                    "is not a valid Parameter.kind"):
+            p3 = copy.replace(p3, kind=p3.empty)
 
         p2 = p2.replace(kind=p2.KEYWORD_ONLY)
         self.assertEqual(p2, p)
+        p3 = copy.replace(p3, kind=p3.KEYWORD_ONLY)
+        self.assertEqual(p3, p)
 
     def test_signature_parameter_positional_only(self):
         with self.assertRaisesRegex(TypeError, 'name must be a str'):
diff --git a/Misc/NEWS.d/next/Library/2023-09-01-13-14-08.gh-issue-108751.2itqwe.rst b/Misc/NEWS.d/next/Library/2023-09-01-13-14-08.gh-issue-108751.2itqwe.rst
new file mode 100644 (file)
index 0000000..7bc21fe
--- /dev/null
@@ -0,0 +1,2 @@
+Add :func:`copy.replace` function which allows to create a modified copy of
+an object. It supports named tuples, dataclasses, and many other objects.
index 191db3f84088d5f310cbcbcd5ef560ea9b196067..0d356779cfe1920622f1820ff0008367da6b0766 100644 (file)
@@ -3590,6 +3590,8 @@ static PyMethodDef date_methods[] = {
     {"replace",     _PyCFunction_CAST(date_replace),      METH_VARARGS | METH_KEYWORDS,
      PyDoc_STR("Return date with new specified fields.")},
 
+    {"__replace__", _PyCFunction_CAST(date_replace),      METH_VARARGS | METH_KEYWORDS},
+
     {"__reduce__", (PyCFunction)date_reduce,        METH_NOARGS,
      PyDoc_STR("__reduce__() -> (cls, state)")},
 
@@ -4719,6 +4721,8 @@ static PyMethodDef time_methods[] = {
     {"replace",     _PyCFunction_CAST(time_replace),          METH_VARARGS | METH_KEYWORDS,
      PyDoc_STR("Return time with new specified fields.")},
 
+    {"__replace__", _PyCFunction_CAST(time_replace),          METH_VARARGS | METH_KEYWORDS},
+
      {"fromisoformat", (PyCFunction)time_fromisoformat, METH_O | METH_CLASS,
      PyDoc_STR("string -> time from a string in ISO 8601 format")},
 
@@ -6579,6 +6583,8 @@ static PyMethodDef datetime_methods[] = {
     {"replace",     _PyCFunction_CAST(datetime_replace),      METH_VARARGS | METH_KEYWORDS,
      PyDoc_STR("Return datetime with new specified fields.")},
 
+    {"__replace__", _PyCFunction_CAST(datetime_replace),      METH_VARARGS | METH_KEYWORDS},
+
     {"astimezone",  _PyCFunction_CAST(datetime_astimezone), METH_VARARGS | METH_KEYWORDS,
      PyDoc_STR("tz -> convert to local time in new timezone tz\n")},
 
index 70a0c2ebd66b2455b002c78f4214b1108987a47e..58306075cad48b75843d916606f0f55dbf57952d 100644 (file)
@@ -2145,6 +2145,7 @@ static struct PyMethodDef code_methods[] = {
     {"co_positions", (PyCFunction)code_positionsiterator, METH_NOARGS},
     CODE_REPLACE_METHODDEF
     CODE__VARNAME_FROM_OPARG_METHODDEF
+    {"__replace__", _PyCFunction_CAST(code_replace), METH_FASTCALL|METH_KEYWORDS},
     {NULL, NULL}                /* sentinel */
 };