# Find our base classes in reverse MRO order, and exclude
# ourselves. In reversed order so that more derived classes
# override earlier field definitions in base classes. As long as
- # we're iterating over them, see if any are frozen.
+ # we're iterating over them, see if all or any of them are frozen.
any_frozen_base = False
+ # By default `all_frozen_bases` is `None` to represent a case,
+ # where some dataclasses does not have any bases with `_FIELDS`
+ all_frozen_bases = None
has_dataclass_bases = False
for b in cls.__mro__[-1:0:-1]:
# Only process classes that have been processed by our
has_dataclass_bases = True
for f in base_fields.values():
fields[f.name] = f
- if getattr(b, _PARAMS).frozen:
- any_frozen_base = True
+ if all_frozen_bases is None:
+ all_frozen_bases = True
+ current_frozen = getattr(b, _PARAMS).frozen
+ all_frozen_bases = all_frozen_bases and current_frozen
+ any_frozen_base = any_frozen_base or current_frozen
# Annotations defined specifically in this class (not in base classes).
#
'frozen one')
# Raise an exception if we're frozen, but none of our bases are.
- if not any_frozen_base and frozen:
+ if all_frozen_bases is False and frozen:
raise TypeError('cannot inherit frozen dataclass from a '
'non-frozen one')
class D(C):
j: int
+ def test_inherit_frozen_mutliple_inheritance(self):
+ @dataclass
+ class NotFrozen:
+ pass
+
+ @dataclass(frozen=True)
+ class Frozen:
+ pass
+
+ class NotDataclass:
+ pass
+
+ for bases in (
+ (NotFrozen, Frozen),
+ (Frozen, NotFrozen),
+ (Frozen, NotDataclass),
+ (NotDataclass, Frozen),
+ ):
+ with self.subTest(bases=bases):
+ with self.assertRaisesRegex(
+ TypeError,
+ 'cannot inherit non-frozen dataclass from a frozen one',
+ ):
+ @dataclass
+ class NotFrozenChild(*bases):
+ pass
+
+ for bases in (
+ (NotFrozen, Frozen),
+ (Frozen, NotFrozen),
+ (NotFrozen, NotDataclass),
+ (NotDataclass, NotFrozen),
+ ):
+ with self.subTest(bases=bases):
+ with self.assertRaisesRegex(
+ TypeError,
+ 'cannot inherit frozen dataclass from a non-frozen one',
+ ):
+ @dataclass(frozen=True)
+ class FrozenChild(*bases):
+ pass
+
+ def test_inherit_frozen_mutliple_inheritance_regular_mixins(self):
+ @dataclass(frozen=True)
+ class Frozen:
+ pass
+
+ class NotDataclass:
+ pass
+
+ class C1(Frozen, NotDataclass):
+ pass
+ self.assertEqual(C1.__mro__, (C1, Frozen, NotDataclass, object))
+
+ class C2(NotDataclass, Frozen):
+ pass
+ self.assertEqual(C2.__mro__, (C2, NotDataclass, Frozen, object))
+
+ @dataclass(frozen=True)
+ class C3(Frozen, NotDataclass):
+ pass
+ self.assertEqual(C3.__mro__, (C3, Frozen, NotDataclass, object))
+
+ @dataclass(frozen=True)
+ class C4(NotDataclass, Frozen):
+ pass
+ self.assertEqual(C4.__mro__, (C4, NotDataclass, Frozen, object))
+
+ def test_multiple_frozen_dataclasses_inheritance(self):
+ @dataclass(frozen=True)
+ class FrozenA:
+ pass
+
+ @dataclass(frozen=True)
+ class FrozenB:
+ pass
+
+ class C1(FrozenA, FrozenB):
+ pass
+ self.assertEqual(C1.__mro__, (C1, FrozenA, FrozenB, object))
+
+ class C2(FrozenB, FrozenA):
+ pass
+ self.assertEqual(C2.__mro__, (C2, FrozenB, FrozenA, object))
+
+ @dataclass(frozen=True)
+ class C3(FrozenA, FrozenB):
+ pass
+ self.assertEqual(C3.__mro__, (C3, FrozenA, FrozenB, object))
+
+ @dataclass(frozen=True)
+ class C4(FrozenB, FrozenA):
+ pass
+ self.assertEqual(C4.__mro__, (C4, FrozenB, FrozenA, object))
+
def test_inherit_nonfrozen_from_empty(self):
@dataclass
class C: