]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-125710: [Enum] fix hashable<->nonhashable comparisons for member values (GH-125735)
authorEthan Furman <ethan@stoneleaf.us>
Tue, 22 Oct 2024 18:04:00 +0000 (11:04 -0700)
committerGitHub <noreply@github.com>
Tue, 22 Oct 2024 18:04:00 +0000 (11:04 -0700)
Lib/enum.py
Lib/test/test_enum.py
Misc/NEWS.d/next/Library/2024-10-19-13-37-37.gh-issue-125710.FyFAAr.rst [new file with mode: 0644]

index 17d72738792982675fccfa741a79571602306aab..4f9912229603a60c976e644135d29784f72755cf 100644 (file)
@@ -327,6 +327,8 @@ class _proto_member:
             # to the map, and by-value lookups for this value will be
             # linear.
             enum_class._value2member_map_.setdefault(value, enum_member)
+            if value not in enum_class._hashable_values_:
+                enum_class._hashable_values_.append(value)
         except TypeError:
             # keep track of the value in a list so containment checks are quick
             enum_class._unhashable_values_.append(value)
@@ -538,7 +540,8 @@ class EnumType(type):
         classdict['_member_names_'] = []
         classdict['_member_map_'] = {}
         classdict['_value2member_map_'] = {}
-        classdict['_unhashable_values_'] = []
+        classdict['_hashable_values_'] = []          # for comparing with non-hashable types
+        classdict['_unhashable_values_'] = []       # e.g. frozenset() with set()
         classdict['_unhashable_values_map_'] = {}
         classdict['_member_type_'] = member_type
         # now set the __repr__ for the value
@@ -748,7 +751,10 @@ class EnumType(type):
         try:
             return value in cls._value2member_map_
         except TypeError:
-            return value in cls._unhashable_values_
+            return (
+                    value in cls._unhashable_values_    # both structures are lists
+                    or value in cls._hashable_values_
+                    )
 
     def __delattr__(cls, attr):
         # nicer error message when someone tries to delete an attribute
@@ -1166,8 +1172,11 @@ class Enum(metaclass=EnumType):
             pass
         except TypeError:
             # not there, now do long search -- O(n) behavior
-            for name, values in cls._unhashable_values_map_.items():
-                if value in values:
+            for name, unhashable_values in cls._unhashable_values_map_.items():
+                if value in unhashable_values:
+                    return cls[name]
+            for name, member in cls._member_map_.items():
+                if value == member._value_:
                     return cls[name]
         # still not found -- verify that members exist, in-case somebody got here mistakenly
         # (such as via super when trying to override __new__)
@@ -1233,6 +1242,7 @@ class Enum(metaclass=EnumType):
             # to the map, and by-value lookups for this value will be
             # linear.
             cls._value2member_map_.setdefault(value, self)
+            cls._hashable_values_.append(value)
         except TypeError:
             # keep track of the value in a list so containment checks are quick
             cls._unhashable_values_.append(value)
@@ -1763,6 +1773,7 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None):
         body['_member_names_'] = member_names = []
         body['_member_map_'] = member_map = {}
         body['_value2member_map_'] = value2member_map = {}
+        body['_hashable_values_'] = hashable_values = []
         body['_unhashable_values_'] = unhashable_values = []
         body['_unhashable_values_map_'] = {}
         body['_member_type_'] = member_type = etype._member_type_
@@ -1826,7 +1837,7 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None):
                     contained = value2member_map.get(member._value_)
                 except TypeError:
                     contained = None
-                    if member._value_ in unhashable_values:
+                    if member._value_ in unhashable_values or member.value in hashable_values:
                         for m in enum_class:
                             if m._value_ == member._value_:
                                 contained = m
@@ -1846,6 +1857,7 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None):
                     else:
                         enum_class._add_member_(name, member)
                     value2member_map[value] = member
+                    hashable_values.append(value)
                     if _is_single_bit(value):
                         # not a multi-bit alias, record in _member_names_ and _flag_mask_
                         member_names.append(name)
@@ -1882,7 +1894,7 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None):
                     contained = value2member_map.get(member._value_)
                 except TypeError:
                     contained = None
-                    if member._value_ in unhashable_values:
+                    if member._value_ in unhashable_values or member._value_ in hashable_values:
                         for m in enum_class:
                             if m._value_ == member._value_:
                                 contained = m
@@ -1908,6 +1920,8 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None):
                         # to the map, and by-value lookups for this value will be
                         # linear.
                         enum_class._value2member_map_.setdefault(value, member)
+                        if value not in hashable_values:
+                            hashable_values.append(value)
                     except TypeError:
                         # keep track of the value in a list so containment checks are quick
                         enum_class._unhashable_values_.append(value)
index 5b4a8070526fcf851ea9fe732b62b35fd3ac7f7e..7184769bfd6fc3261a22039730c579609eca71b9 100644 (file)
@@ -3460,6 +3460,13 @@ class TestSpecial(unittest.TestCase):
         self.assertRaisesRegex(TypeError, '.int. object is not iterable', Enum, 'bad_enum', names=0)
         self.assertRaisesRegex(TypeError, '.int. object is not iterable', Enum, 'bad_enum', 0, type=int)
 
+    def test_nonhashable_matches_hashable(self):    # issue 125710
+        class Directions(Enum):
+            DOWN_ONLY = frozenset({"sc"})
+            UP_ONLY = frozenset({"cs"})
+            UNRESTRICTED = frozenset({"sc", "cs"})
+        self.assertIs(Directions({"sc"}), Directions.DOWN_ONLY)
+
 
 class TestOrder(unittest.TestCase):
     "test usage of the `_order_` attribute"
diff --git a/Misc/NEWS.d/next/Library/2024-10-19-13-37-37.gh-issue-125710.FyFAAr.rst b/Misc/NEWS.d/next/Library/2024-10-19-13-37-37.gh-issue-125710.FyFAAr.rst
new file mode 100644 (file)
index 0000000..8d5220e
--- /dev/null
@@ -0,0 +1 @@
+[Enum] fix hashable<->nonhashable comparisons for member values