]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-118033: Fix `__weakref__` not set for generic dataclasses (#118099)
authorNikita Sobolev <mail@sobolevn.me>
Thu, 9 May 2024 08:36:17 +0000 (11:36 +0300)
committerGitHub <noreply@github.com>
Thu, 9 May 2024 08:36:17 +0000 (11:36 +0300)
Lib/dataclasses.py
Lib/test/test_dataclasses/__init__.py
Misc/NEWS.d/next/Library/2024-04-19-14-59-53.gh-issue-118033.amS4Gw.rst [new file with mode: 0644]

index 3acd03cd86523473172891b791ef52e6286a4af3..aeafbfbbe6e9c4900220c6117033e5eda5fc5f6d 100644 (file)
@@ -1199,10 +1199,17 @@ def _dataclass_setstate(self, state):
 
 def _get_slots(cls):
     match cls.__dict__.get('__slots__'):
-        # A class which does not define __slots__ at all is equivalent
-        # to a class defining __slots__ = ('__dict__', '__weakref__')
+        # `__dictoffset__` and `__weakrefoffset__` can tell us whether
+        # the base type has dict/weakref slots, in a way that works correctly
+        # for both Python classes and C extension types. Extension types
+        # don't use `__slots__` for slot creation
         case None:
-            yield from ('__dict__', '__weakref__')
+            slots = []
+            if getattr(cls, '__weakrefoffset__', -1) != 0:
+                slots.append('__weakref__')
+            if getattr(cls, '__dictrefoffset__', -1) != 0:
+                slots.append('__dict__')
+            yield from slots
         case str(slot):
             yield slot
         # Slots may be any iterable, but we cannot handle an iterator
index 832e5672c77d0dab65819e85596648a2092d7160..ea49596eaa4d969085c4fc91f75646bcdd0d0d81 100644 (file)
@@ -3515,8 +3515,114 @@ class TestSlots(unittest.TestCase):
         class B(A):
             pass
 
+        self.assertEqual(B.__slots__, ())
         B()
 
+    def test_dataclass_derived_generic(self):
+        T = typing.TypeVar('T')
+
+        @dataclass(slots=True, weakref_slot=True)
+        class A(typing.Generic[T]):
+            pass
+        self.assertEqual(A.__slots__, ('__weakref__',))
+        self.assertTrue(A.__weakref__)
+        A()
+
+        @dataclass(slots=True, weakref_slot=True)
+        class B[T2]:
+            pass
+        self.assertEqual(B.__slots__, ('__weakref__',))
+        self.assertTrue(B.__weakref__)
+        B()
+
+    def test_dataclass_derived_generic_from_base(self):
+        T = typing.TypeVar('T')
+
+        class RawBase: ...
+
+        @dataclass(slots=True, weakref_slot=True)
+        class C1(typing.Generic[T], RawBase):
+            pass
+        self.assertEqual(C1.__slots__, ())
+        self.assertTrue(C1.__weakref__)
+        C1()
+        @dataclass(slots=True, weakref_slot=True)
+        class C2(RawBase, typing.Generic[T]):
+            pass
+        self.assertEqual(C2.__slots__, ())
+        self.assertTrue(C2.__weakref__)
+        C2()
+
+        @dataclass(slots=True, weakref_slot=True)
+        class D[T2](RawBase):
+            pass
+        self.assertEqual(D.__slots__, ())
+        self.assertTrue(D.__weakref__)
+        D()
+
+    def test_dataclass_derived_generic_from_slotted_base(self):
+        T = typing.TypeVar('T')
+
+        class WithSlots:
+            __slots__ = ('a', 'b')
+
+        @dataclass(slots=True, weakref_slot=True)
+        class E1(WithSlots, Generic[T]):
+            pass
+        self.assertEqual(E1.__slots__, ('__weakref__',))
+        self.assertTrue(E1.__weakref__)
+        E1()
+        @dataclass(slots=True, weakref_slot=True)
+        class E2(Generic[T], WithSlots):
+            pass
+        self.assertEqual(E2.__slots__, ('__weakref__',))
+        self.assertTrue(E2.__weakref__)
+        E2()
+
+        @dataclass(slots=True, weakref_slot=True)
+        class F[T2](WithSlots):
+            pass
+        self.assertEqual(F.__slots__, ('__weakref__',))
+        self.assertTrue(F.__weakref__)
+        F()
+
+    def test_dataclass_derived_generic_from_slotted_base(self):
+        T = typing.TypeVar('T')
+
+        class WithWeakrefSlot:
+            __slots__ = ('__weakref__',)
+
+        @dataclass(slots=True, weakref_slot=True)
+        class G1(WithWeakrefSlot, Generic[T]):
+            pass
+        self.assertEqual(G1.__slots__, ())
+        self.assertTrue(G1.__weakref__)
+        G1()
+        @dataclass(slots=True, weakref_slot=True)
+        class G2(Generic[T], WithWeakrefSlot):
+            pass
+        self.assertEqual(G2.__slots__, ())
+        self.assertTrue(G2.__weakref__)
+        G2()
+
+        @dataclass(slots=True, weakref_slot=True)
+        class H[T2](WithWeakrefSlot):
+            pass
+        self.assertEqual(H.__slots__, ())
+        self.assertTrue(H.__weakref__)
+        H()
+
+    def test_dataclass_slot_dict(self):
+        class WithDictSlot:
+            __slots__ = ('__dict__',)
+
+        @dataclass(slots=True)
+        class A(WithDictSlot): ...
+
+        self.assertEqual(A.__slots__, ())
+        self.assertEqual(A().__dict__, {})
+        A()
+
 
 class TestDescriptors(unittest.TestCase):
     def test_set_name(self):
diff --git a/Misc/NEWS.d/next/Library/2024-04-19-14-59-53.gh-issue-118033.amS4Gw.rst b/Misc/NEWS.d/next/Library/2024-04-19-14-59-53.gh-issue-118033.amS4Gw.rst
new file mode 100644 (file)
index 0000000..7ceb293
--- /dev/null
@@ -0,0 +1,2 @@
+Fix :func:`dataclasses.dataclass` not creating a ``__weakref__`` slot when
+subclassing :class:`typing.Generic`.