From: Miss Islington (bot) <31488909+miss-islington@users.noreply.github.com> Date: Tue, 20 Feb 2024 00:18:30 +0000 (+0100) Subject: [3.11] gh-115539: Allow enum.Flag to have None members (GH-115636) (GH-115695) X-Git-Tag: v3.11.9~168 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=b3384af46406c171e4e15fc6fb0901efc1eb6cd4;p=thirdparty%2FPython%2Fcpython.git [3.11] gh-115539: Allow enum.Flag to have None members (GH-115636) (GH-115695) gh-115539: Allow enum.Flag to have None members (GH-115636) (cherry picked from commit c2cb31bbe1262213085c425bc853d6587c66cae9) Co-authored-by: Jason Zhang --- diff --git a/Lib/enum.py b/Lib/enum.py index 63833aaa717f..8c705cf392cf 100644 --- a/Lib/enum.py +++ b/Lib/enum.py @@ -276,9 +276,10 @@ class _proto_member: enum_member._sort_order_ = len(enum_class._member_names_) if Flag is not None and issubclass(enum_class, Flag): - enum_class._flag_mask_ |= value - if _is_single_bit(value): - enum_class._singles_mask_ |= value + if isinstance(value, int): + enum_class._flag_mask_ |= value + if _is_single_bit(value): + enum_class._singles_mask_ |= value enum_class._all_bits_ = 2 ** ((enum_class._flag_mask_).bit_length()) - 1 # If another member with the same value was already defined, the @@ -306,6 +307,7 @@ class _proto_member: elif ( Flag is not None and issubclass(enum_class, Flag) + and isinstance(value, int) and _is_single_bit(value) ): # no other instances found, record this member in _member_names_ @@ -1502,37 +1504,50 @@ class Flag(Enum, boundary=STRICT): def __bool__(self): return bool(self._value_) + def _get_value(self, flag): + if isinstance(flag, self.__class__): + return flag._value_ + elif self._member_type_ is not object and isinstance(flag, self._member_type_): + return flag + return NotImplemented + def __or__(self, other): - if isinstance(other, self.__class__): - other = other._value_ - elif self._member_type_ is not object and isinstance(other, self._member_type_): - other = other - else: + other_value = self._get_value(other) + if other_value is NotImplemented: return NotImplemented + + for flag in self, other: + if self._get_value(flag) is None: + raise TypeError(f"'{flag}' cannot be combined with other flags with |") value = self._value_ - return self.__class__(value | other) + return self.__class__(value | other_value) def __and__(self, other): - if isinstance(other, self.__class__): - other = other._value_ - elif self._member_type_ is not object and isinstance(other, self._member_type_): - other = other - else: + other_value = self._get_value(other) + if other_value is NotImplemented: return NotImplemented + + for flag in self, other: + if self._get_value(flag) is None: + raise TypeError(f"'{flag}' cannot be combined with other flags with &") value = self._value_ - return self.__class__(value & other) + return self.__class__(value & other_value) def __xor__(self, other): - if isinstance(other, self.__class__): - other = other._value_ - elif self._member_type_ is not object and isinstance(other, self._member_type_): - other = other - else: + other_value = self._get_value(other) + if other_value is NotImplemented: return NotImplemented + + for flag in self, other: + if self._get_value(flag) is None: + raise TypeError(f"'{flag}' cannot be combined with other flags with ^") value = self._value_ - return self.__class__(value ^ other) + return self.__class__(value ^ other_value) def __invert__(self): + if self._get_value(self) is None: + raise TypeError(f"'{self}' cannot be inverted") + if self._inverted_ is None: if self._boundary_ in (EJECT, KEEP): self._inverted_ = self.__class__(~self._value_) diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py index f37e2f7e1a6f..e642b84d7216 100644 --- a/Lib/test/test_enum.py +++ b/Lib/test/test_enum.py @@ -903,6 +903,22 @@ class TestPlainEnum(_EnumTests, _PlainOutputTests, unittest.TestCase): class TestPlainFlag(_EnumTests, _PlainOutputTests, _FlagTests, unittest.TestCase): enum_type = Flag + def test_none_member(self): + class FlagWithNoneMember(Flag): + A = 1 + E = None + + self.assertEqual(FlagWithNoneMember.A.value, 1) + self.assertIs(FlagWithNoneMember.E.value, None) + with self.assertRaisesRegex(TypeError, r"'FlagWithNoneMember.E' cannot be combined with other flags with |"): + FlagWithNoneMember.A | FlagWithNoneMember.E + with self.assertRaisesRegex(TypeError, r"'FlagWithNoneMember.E' cannot be combined with other flags with &"): + FlagWithNoneMember.E & FlagWithNoneMember.A + with self.assertRaisesRegex(TypeError, r"'FlagWithNoneMember.E' cannot be combined with other flags with \^"): + FlagWithNoneMember.A ^ FlagWithNoneMember.E + with self.assertRaisesRegex(TypeError, r"'FlagWithNoneMember.E' cannot be inverted"): + ~FlagWithNoneMember.E + class TestIntEnum(_EnumTests, _MinimalOutputTests, unittest.TestCase): enum_type = IntEnum