]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-94601: [Enum] fix inheritance for __str__ and friends (GH-94942)
authorEthan Furman <ethan@stoneleaf.us>
Mon, 18 Jul 2022 01:51:04 +0000 (18:51 -0700)
committerGitHub <noreply@github.com>
Mon, 18 Jul 2022 01:51:04 +0000 (18:51 -0700)
Lib/enum.py
Lib/test/test_enum.py

index a4f1f09adae01c561f700395ab7bed7d90e6899c..80945c116bfe581327eb027008b9f74e8fe5553d 100644 (file)
@@ -247,7 +247,10 @@ class _proto_member:
         if not enum_class._use_args_:
             enum_member = enum_class._new_member_(enum_class)
             if not hasattr(enum_member, '_value_'):
-                enum_member._value_ = value
+                try:
+                    enum_member._value_ = enum_class._member_type_(*args)
+                except Exception as exc:
+                    enum_member._value_ = value
         else:
             enum_member = enum_class._new_member_(enum_class, *args)
             if not hasattr(enum_member, '_value_'):
@@ -562,7 +565,13 @@ class EnumType(type):
                 classdict['__str__'] = enum_class.__str__
         for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'):
             if name not in classdict:
-                setattr(enum_class, name, getattr(first_enum, name))
+                # check for mixin overrides before replacing
+                enum_method = getattr(first_enum, name)
+                found_method = getattr(enum_class, name)
+                object_method = getattr(object, name)
+                data_type_method = getattr(member_type, name)
+                if found_method in (data_type_method, object_method):
+                    setattr(enum_class, name, enum_method)
         #
         # for Flag, add __or__, __and__, __xor__, and __invert__
         if Flag is not None and issubclass(enum_class, Flag):
@@ -937,16 +946,18 @@ class EnumType(type):
     @classmethod
     def _find_data_type_(mcls, class_name, bases):
         data_types = set()
+        base_chain = set()
         for chain in bases:
             candidate = None
             for base in chain.__mro__:
+                base_chain.add(base)
                 if base is object:
                     continue
                 elif issubclass(base, Enum):
                     if base._member_type_ is not object:
                         data_types.add(base._member_type_)
                         break
-                elif '__new__' in base.__dict__:
+                elif '__new__' in base.__dict__ or '__init__' in base.__dict__:
                     if issubclass(base, Enum):
                         continue
                     data_types.add(candidate or base)
@@ -1658,7 +1669,13 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None):
         enum_class = type(cls_name, (etype, ), body, boundary=boundary, _simple=True)
         for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'):
             if name not in body:
-                setattr(enum_class, name, getattr(etype, name))
+                # check for mixin overrides before replacing
+                enum_method = getattr(etype, name)
+                found_method = getattr(enum_class, name)
+                object_method = getattr(object, name)
+                data_type_method = getattr(member_type, name)
+                if found_method in (data_type_method, object_method):
+                    setattr(enum_class, name, enum_method)
         gnv_last_values = []
         if issubclass(enum_class, Flag):
             # Flag / IntFlag
@@ -1989,7 +2006,6 @@ def _old_convert_(etype, name, module, filter, source=None, *, boundary=None):
         members.sort(key=lambda t: t[0])
     cls = etype(name, members, module=module, boundary=boundary or KEEP)
     cls.__reduce_ex__ = _reduce_ex_by_global_name
-    cls.__repr__ = global_enum_repr
     return cls
 
 _stdlib_enums = IntEnum, StrEnum, IntFlag
index 87d7c725014e3eacbeb71c45b2035f76cb056045..69fba9a13c89a4b9f4bc53fa66441241d8f26c4e 100644 (file)
@@ -2693,12 +2693,15 @@ class TestSpecial(unittest.TestCase):
         @dataclass
         class Foo:
             __qualname__ = 'Foo'
-            a: int = 0
+            a: int
         class Entries(Foo, Enum):
-            ENTRY1 = Foo(1)
+            ENTRY1 = 1
+        self.assertTrue(isinstance(Entries.ENTRY1, Foo))
+        self.assertTrue(Entries._member_type_ is Foo, Entries._member_type_)
+        self.assertTrue(Entries.ENTRY1.value == Foo(1), Entries.ENTRY1.value)
         self.assertEqual(repr(Entries.ENTRY1), '<Entries.ENTRY1: Foo(a=1)>')
 
-    def test_repr_with_non_data_type_mixin(self):
+    def test_repr_with_init_data_type_mixin(self):
         # non-data_type is a mixin that doesn't define __new__
         class Foo:
             def __init__(self, a):
@@ -2706,10 +2709,23 @@ class TestSpecial(unittest.TestCase):
             def __repr__(self):
                 return f'Foo(a={self.a!r})'
         class Entries(Foo, Enum):
-            ENTRY1 = Foo(1)
-
+            ENTRY1 = 1
+        #
         self.assertEqual(repr(Entries.ENTRY1), '<Entries.ENTRY1: Foo(a=1)>')
 
+    def test_repr_and_str_with_non_data_type_mixin(self):
+        # non-data_type is a mixin that doesn't define __new__
+        class Foo:
+            def __repr__(self):
+                return 'Foo'
+            def __str__(self):
+                return 'ooF'
+        class Entries(Foo, Enum):
+            ENTRY1 = 1
+        #
+        self.assertEqual(repr(Entries.ENTRY1), 'Foo')
+        self.assertEqual(str(Entries.ENTRY1), 'ooF')
+
     def test_value_backup_assign(self):
         # check that enum will add missing values when custom __new__ does not
         class Some(Enum):