]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-90562: Support zero argument super with dataclasses when slots=True (gh-124455)
authorEric V. Smith <ericvsmith@users.noreply.github.com>
Wed, 25 Sep 2024 01:26:26 +0000 (21:26 -0400)
committerGitHub <noreply@github.com>
Wed, 25 Sep 2024 01:26:26 +0000 (21:26 -0400)
Co-authored-by: @wookie184
Co-authored-by: Carl Meyer <carl@oddbird.net>
Doc/library/dataclasses.rst
Lib/dataclasses.py
Lib/test/test_dataclasses/__init__.py
Misc/NEWS.d/next/Library/2024-09-23-18-26-17.gh-issue-90562.Yj566G.rst [new file with mode: 0644]

index cfca11afbd2e41cf4a07685a28273804889f30a2..1457392ce6e86ca1074cea576f628c5f092ea6d0 100644 (file)
@@ -187,13 +187,6 @@ Module contents
      If :attr:`!__slots__` is already defined in the class, then :exc:`TypeError`
      is raised.
 
-    .. warning::
-        Calling no-arg :func:`super` in dataclasses using ``slots=True``
-        will result in the following exception being raised:
-        ``TypeError: super(type, obj): obj must be an instance or subtype of type``.
-        The two-arg :func:`super` is a valid workaround.
-        See :gh:`90562` for full details.
-
     .. warning::
        Passing parameters to a base class :meth:`~object.__init_subclass__`
        when using ``slots=True`` will result in a :exc:`TypeError`.
index 6255d8980974e0c70a65d73d0b32ba720f74107a..f5cb97edaf72cdb4f090ed5680b33401ee49a1ac 100644 (file)
@@ -1218,9 +1218,31 @@ def _get_slots(cls):
             raise TypeError(f"Slots of '{cls.__name__}' cannot be determined")
 
 
+def _update_func_cell_for__class__(f, oldcls, newcls):
+    # Returns True if we update a cell, else False.
+    if f is None:
+        # f will be None in the case of a property where not all of
+        # fget, fset, and fdel are used.  Nothing to do in that case.
+        return False
+    try:
+        idx = f.__code__.co_freevars.index("__class__")
+    except ValueError:
+        # This function doesn't reference __class__, so nothing to do.
+        return False
+    # Fix the cell to point to the new class, if it's already pointing
+    # at the old class.  I'm not convinced that the "is oldcls" test
+    # is needed, but other than performance can't hurt.
+    closure = f.__closure__[idx]
+    if closure.cell_contents is oldcls:
+        closure.cell_contents = newcls
+        return True
+    return False
+
+
 def _add_slots(cls, is_frozen, weakref_slot):
-    # Need to create a new class, since we can't set __slots__
-    #  after a class has been created.
+    # Need to create a new class, since we can't set __slots__ after a
+    # class has been created, and the @dataclass decorator is called
+    # after the class is created.
 
     # Make sure __slots__ isn't already set.
     if '__slots__' in cls.__dict__:
@@ -1259,18 +1281,37 @@ def _add_slots(cls, is_frozen, weakref_slot):
 
     # And finally create the class.
     qualname = getattr(cls, '__qualname__', None)
-    cls = type(cls)(cls.__name__, cls.__bases__, cls_dict)
+    newcls = type(cls)(cls.__name__, cls.__bases__, cls_dict)
     if qualname is not None:
-        cls.__qualname__ = qualname
+        newcls.__qualname__ = qualname
 
     if is_frozen:
         # Need this for pickling frozen classes with slots.
         if '__getstate__' not in cls_dict:
-            cls.__getstate__ = _dataclass_getstate
+            newcls.__getstate__ = _dataclass_getstate
         if '__setstate__' not in cls_dict:
-            cls.__setstate__ = _dataclass_setstate
-
-    return cls
+            newcls.__setstate__ = _dataclass_setstate
+
+    # Fix up any closures which reference __class__.  This is used to
+    # fix zero argument super so that it points to the correct class
+    # (the newly created one, which we're returning) and not the
+    # original class.  We can break out of this loop as soon as we
+    # make an update, since all closures for a class will share a
+    # given cell.
+    for member in newcls.__dict__.values():
+        # If this is a wrapped function, unwrap it.
+        member = inspect.unwrap(member)
+
+        if isinstance(member, types.FunctionType):
+            if _update_func_cell_for__class__(member, cls, newcls):
+                break
+        elif isinstance(member, property):
+            if (_update_func_cell_for__class__(member.fget, cls, newcls)
+                or _update_func_cell_for__class__(member.fset, cls, newcls)
+                or _update_func_cell_for__class__(member.fdel, cls, newcls)):
+                break
+
+    return newcls
 
 
 def dataclass(cls=None, /, *, init=True, repr=True, eq=True, order=False,
index 6934e88d9d338cf1ff01bddd840b2ce92488f186..69e86162e0c11a3858feba84a53e39f8cb12dc34 100644 (file)
@@ -17,7 +17,7 @@ from unittest.mock import Mock
 from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional, Protocol, DefaultDict
 from typing import get_type_hints
 from collections import deque, OrderedDict, namedtuple, defaultdict
-from functools import total_ordering
+from functools import total_ordering, wraps
 
 import typing       # Needed for the string "typing.ClassVar[int]" to work as an annotation.
 import dataclasses  # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
@@ -4869,5 +4869,129 @@ class TestKeywordArgs(unittest.TestCase):
         self.assertEqual(fs[0].name, 'x')
 
 
+class TestZeroArgumentSuperWithSlots(unittest.TestCase):
+    def test_zero_argument_super(self):
+        @dataclass(slots=True)
+        class A:
+            def foo(self):
+                super()
+
+        A().foo()
+
+    def test_dunder_class_with_old_property(self):
+        @dataclass(slots=True)
+        class A:
+            def _get_foo(slf):
+                self.assertIs(__class__, type(slf))
+                self.assertIs(__class__, slf.__class__)
+                return __class__
+
+            def _set_foo(slf, value):
+                self.assertIs(__class__, type(slf))
+                self.assertIs(__class__, slf.__class__)
+
+            def _del_foo(slf):
+                self.assertIs(__class__, type(slf))
+                self.assertIs(__class__, slf.__class__)
+
+            foo = property(_get_foo, _set_foo, _del_foo)
+
+        a = A()
+        self.assertIs(a.foo, A)
+        a.foo = 4
+        del a.foo
+
+    def test_dunder_class_with_new_property(self):
+        @dataclass(slots=True)
+        class A:
+            @property
+            def foo(slf):
+                return slf.__class__
+
+            @foo.setter
+            def foo(slf, value):
+                self.assertIs(__class__, type(slf))
+
+            @foo.deleter
+            def foo(slf):
+                self.assertIs(__class__, type(slf))
+
+        a = A()
+        self.assertIs(a.foo, A)
+        a.foo = 4
+        del a.foo
+
+    # Test the parts of a property individually.
+    def test_slots_dunder_class_property_getter(self):
+        @dataclass(slots=True)
+        class A:
+            @property
+            def foo(slf):
+                return __class__
+
+        a = A()
+        self.assertIs(a.foo, A)
+
+    def test_slots_dunder_class_property_setter(self):
+        @dataclass(slots=True)
+        class A:
+            foo = property()
+            @foo.setter
+            def foo(slf, val):
+                self.assertIs(__class__, type(slf))
+
+        a = A()
+        a.foo = 4
+
+    def test_slots_dunder_class_property_deleter(self):
+        @dataclass(slots=True)
+        class A:
+            foo = property()
+            @foo.deleter
+            def foo(slf):
+                self.assertIs(__class__, type(slf))
+
+        a = A()
+        del a.foo
+
+    def test_wrapped(self):
+        def mydecorator(f):
+            @wraps(f)
+            def wrapper(*args, **kwargs):
+                return f(*args, **kwargs)
+            return wrapper
+
+        @dataclass(slots=True)
+        class A:
+            @mydecorator
+            def foo(self):
+                super()
+
+        A().foo()
+
+    def test_remembered_class(self):
+        # Apply the dataclass decorator manually (not when the class
+        # is created), so that we can keep a reference to the
+        # undecorated class.
+        class A:
+            def cls(self):
+                return __class__
+
+        self.assertIs(A().cls(), A)
+
+        B = dataclass(slots=True)(A)
+        self.assertIs(B().cls(), B)
+
+        # This is undesirable behavior, but is a function of how
+        # modifying __class__ in the closure works.  I'm not sure this
+        # should be tested or not: I don't really want to guarantee
+        # this behavior, but I don't want to lose the point that this
+        # is how it works.
+
+        # The underlying class is "broken" by changing its __class__
+        # in A.foo() to B.  This normally isn't a problem, because no
+        # one will be keeping a reference to the underlying class A.
+        self.assertIs(A().cls(), B)
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/Misc/NEWS.d/next/Library/2024-09-23-18-26-17.gh-issue-90562.Yj566G.rst b/Misc/NEWS.d/next/Library/2024-09-23-18-26-17.gh-issue-90562.Yj566G.rst
new file mode 100644 (file)
index 0000000..7a389fe
--- /dev/null
@@ -0,0 +1,3 @@
+Modify dataclasses to support zero-argument super() when ``slots=True`` is
+specified.  This works by modifying all references to ``__class__`` to point
+to the newly created class.