]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-109409: Fix inheritance of frozen dataclass from non-frozen dataclass mixins ...
authorNikita Sobolev <mail@sobolevn.me>
Thu, 12 Oct 2023 13:05:23 +0000 (16:05 +0300)
committerGitHub <noreply@github.com>
Thu, 12 Oct 2023 13:05:23 +0000 (09:05 -0400)
Fix inheritance of frozen dataclass from non-frozen dataclass mixins

Lib/dataclasses.py
Lib/test/test_dataclasses/__init__.py
Misc/NEWS.d/next/Library/2023-09-15-10-42-30.gh-issue-109409.RlffA3.rst [new file with mode: 0644]

index 31dc6f8abce91a7c57d026b5769e1b53fe6510f4..2fba32b5ffbc1e2e37c30e0af00a6d779ad14e39 100644 (file)
@@ -944,8 +944,11 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
     # Find our base classes in reverse MRO order, and exclude
     # ourselves.  In reversed order so that more derived classes
     # override earlier field definitions in base classes.  As long as
-    # we're iterating over them, see if any are frozen.
+    # we're iterating over them, see if all or any of them are frozen.
     any_frozen_base = False
+    # By default `all_frozen_bases` is `None` to represent a case,
+    # where some dataclasses does not have any bases with `_FIELDS`
+    all_frozen_bases = None
     has_dataclass_bases = False
     for b in cls.__mro__[-1:0:-1]:
         # Only process classes that have been processed by our
@@ -955,8 +958,11 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
             has_dataclass_bases = True
             for f in base_fields.values():
                 fields[f.name] = f
-            if getattr(b, _PARAMS).frozen:
-                any_frozen_base = True
+            if all_frozen_bases is None:
+                all_frozen_bases = True
+            current_frozen = getattr(b, _PARAMS).frozen
+            all_frozen_bases = all_frozen_bases and current_frozen
+            any_frozen_base = any_frozen_base or current_frozen
 
     # Annotations defined specifically in this class (not in base classes).
     #
@@ -1025,7 +1031,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
                             'frozen one')
 
         # Raise an exception if we're frozen, but none of our bases are.
-        if not any_frozen_base and frozen:
+        if all_frozen_bases is False and frozen:
             raise TypeError('cannot inherit frozen dataclass from a '
                             'non-frozen one')
 
index f629d7bb53959b271996745b654f5905b542864a..272d427875ae4073114c1a985a16d80f2fcd7f60 100644 (file)
@@ -2863,6 +2863,101 @@ class TestFrozen(unittest.TestCase):
             class D(C):
                 j: int
 
+    def test_inherit_frozen_mutliple_inheritance(self):
+        @dataclass
+        class NotFrozen:
+            pass
+
+        @dataclass(frozen=True)
+        class Frozen:
+            pass
+
+        class NotDataclass:
+            pass
+
+        for bases in (
+            (NotFrozen, Frozen),
+            (Frozen, NotFrozen),
+            (Frozen, NotDataclass),
+            (NotDataclass, Frozen),
+        ):
+            with self.subTest(bases=bases):
+                with self.assertRaisesRegex(
+                    TypeError,
+                    'cannot inherit non-frozen dataclass from a frozen one',
+                ):
+                    @dataclass
+                    class NotFrozenChild(*bases):
+                        pass
+
+        for bases in (
+            (NotFrozen, Frozen),
+            (Frozen, NotFrozen),
+            (NotFrozen, NotDataclass),
+            (NotDataclass, NotFrozen),
+        ):
+            with self.subTest(bases=bases):
+                with self.assertRaisesRegex(
+                    TypeError,
+                    'cannot inherit frozen dataclass from a non-frozen one',
+                ):
+                    @dataclass(frozen=True)
+                    class FrozenChild(*bases):
+                        pass
+
+    def test_inherit_frozen_mutliple_inheritance_regular_mixins(self):
+        @dataclass(frozen=True)
+        class Frozen:
+            pass
+
+        class NotDataclass:
+            pass
+
+        class C1(Frozen, NotDataclass):
+            pass
+        self.assertEqual(C1.__mro__, (C1, Frozen, NotDataclass, object))
+
+        class C2(NotDataclass, Frozen):
+            pass
+        self.assertEqual(C2.__mro__, (C2, NotDataclass, Frozen, object))
+
+        @dataclass(frozen=True)
+        class C3(Frozen, NotDataclass):
+            pass
+        self.assertEqual(C3.__mro__, (C3, Frozen, NotDataclass, object))
+
+        @dataclass(frozen=True)
+        class C4(NotDataclass, Frozen):
+            pass
+        self.assertEqual(C4.__mro__, (C4, NotDataclass, Frozen, object))
+
+    def test_multiple_frozen_dataclasses_inheritance(self):
+        @dataclass(frozen=True)
+        class FrozenA:
+            pass
+
+        @dataclass(frozen=True)
+        class FrozenB:
+            pass
+
+        class C1(FrozenA, FrozenB):
+            pass
+        self.assertEqual(C1.__mro__, (C1, FrozenA, FrozenB, object))
+
+        class C2(FrozenB, FrozenA):
+            pass
+        self.assertEqual(C2.__mro__, (C2, FrozenB, FrozenA, object))
+
+        @dataclass(frozen=True)
+        class C3(FrozenA, FrozenB):
+            pass
+        self.assertEqual(C3.__mro__, (C3, FrozenA, FrozenB, object))
+
+        @dataclass(frozen=True)
+        class C4(FrozenB, FrozenA):
+            pass
+        self.assertEqual(C4.__mro__, (C4, FrozenB, FrozenA, object))
+
     def test_inherit_nonfrozen_from_empty(self):
         @dataclass
         class C:
diff --git a/Misc/NEWS.d/next/Library/2023-09-15-10-42-30.gh-issue-109409.RlffA3.rst b/Misc/NEWS.d/next/Library/2023-09-15-10-42-30.gh-issue-109409.RlffA3.rst
new file mode 100644 (file)
index 0000000..eddad64
--- /dev/null
@@ -0,0 +1,2 @@
+Fix error when it was possible to inherit a frozen dataclass from multiple
+parents some of which were possibly not frozen.