]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-94601: [Enum] fix inheritance for __str__ and friends (GH-94942)
authorMiss Islington (bot) <31488909+miss-islington@users.noreply.github.com>
Mon, 18 Jul 2022 02:18:41 +0000 (19:18 -0700)
committerGitHub <noreply@github.com>
Mon, 18 Jul 2022 02:18:41 +0000 (19:18 -0700)
(cherry picked from commit c961d14f85a0e3e53d5ad1182206ef34030f10b8)

Co-authored-by: Ethan Furman <ethan@stoneleaf.us>
Lib/enum.py
Lib/test/test_enum.py

index b19d40cbc5ed9b57ca9c3624a9dd8720879445b3..f5c29edffcf9fc1a157189919011c5e2f8958134 100644 (file)
@@ -247,7 +247,10 @@ class _proto_member:
         if not enum_class._use_args_:
             enum_member = enum_class._new_member_(enum_class)
             if not hasattr(enum_member, '_value_'):
-                enum_member._value_ = value
+                try:
+                    enum_member._value_ = enum_class._member_type_(*args)
+                except Exception as exc:
+                    enum_member._value_ = value
         else:
             enum_member = enum_class._new_member_(enum_class, *args)
             if not hasattr(enum_member, '_value_'):
@@ -562,7 +565,13 @@ class EnumType(type):
                 classdict['__str__'] = enum_class.__str__
         for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'):
             if name not in classdict:
-                setattr(enum_class, name, getattr(first_enum, name))
+                # check for mixin overrides before replacing
+                enum_method = getattr(first_enum, name)
+                found_method = getattr(enum_class, name)
+                object_method = getattr(object, name)
+                data_type_method = getattr(member_type, name)
+                if found_method in (data_type_method, object_method):
+                    setattr(enum_class, name, enum_method)
         #
         # for Flag, add __or__, __and__, __xor__, and __invert__
         if Flag is not None and issubclass(enum_class, Flag):
@@ -950,16 +959,18 @@ class EnumType(type):
     @classmethod
     def _find_data_type_(mcls, class_name, bases):
         data_types = set()
+        base_chain = set()
         for chain in bases:
             candidate = None
             for base in chain.__mro__:
+                base_chain.add(base)
                 if base is object:
                     continue
                 elif issubclass(base, Enum):
                     if base._member_type_ is not object:
                         data_types.add(base._member_type_)
                         break
-                elif '__new__' in base.__dict__:
+                elif '__new__' in base.__dict__ or '__init__' in base.__dict__:
                     if issubclass(base, Enum):
                         continue
                     data_types.add(candidate or base)
@@ -1671,7 +1682,13 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None):
         enum_class = type(cls_name, (etype, ), body, boundary=boundary, _simple=True)
         for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'):
             if name not in body:
-                setattr(enum_class, name, getattr(etype, name))
+                # check for mixin overrides before replacing
+                enum_method = getattr(etype, name)
+                found_method = getattr(enum_class, name)
+                object_method = getattr(object, name)
+                data_type_method = getattr(member_type, name)
+                if found_method in (data_type_method, object_method):
+                    setattr(enum_class, name, enum_method)
         gnv_last_values = []
         if issubclass(enum_class, Flag):
             # Flag / IntFlag
@@ -2002,7 +2019,6 @@ def _old_convert_(etype, name, module, filter, source=None, *, boundary=None):
         members.sort(key=lambda t: t[0])
     cls = etype(name, members, module=module, boundary=boundary or KEEP)
     cls.__reduce_ex__ = _reduce_ex_by_global_name
-    cls.__repr__ = global_enum_repr
     return cls
 
 _stdlib_enums = IntEnum, StrEnum, IntFlag
index 74f31bec50c4f78de0f15d5d7a0b468d9e167e57..80834f2529ae9e2c90cff280e8cbc148f5b7e74d 100644 (file)
@@ -2658,12 +2658,15 @@ class TestSpecial(unittest.TestCase):
         @dataclass
         class Foo:
             __qualname__ = 'Foo'
-            a: int = 0
+            a: int
         class Entries(Foo, Enum):
-            ENTRY1 = Foo(1)
+            ENTRY1 = 1
+        self.assertTrue(isinstance(Entries.ENTRY1, Foo))
+        self.assertTrue(Entries._member_type_ is Foo, Entries._member_type_)
+        self.assertTrue(Entries.ENTRY1.value == Foo(1), Entries.ENTRY1.value)
         self.assertEqual(repr(Entries.ENTRY1), '<Entries.ENTRY1: Foo(a=1)>')
 
-    def test_repr_with_non_data_type_mixin(self):
+    def test_repr_with_init_data_type_mixin(self):
         # non-data_type is a mixin that doesn't define __new__
         class Foo:
             def __init__(self, a):
@@ -2671,10 +2674,23 @@ class TestSpecial(unittest.TestCase):
             def __repr__(self):
                 return f'Foo(a={self.a!r})'
         class Entries(Foo, Enum):
-            ENTRY1 = Foo(1)
-
+            ENTRY1 = 1
+        #
         self.assertEqual(repr(Entries.ENTRY1), '<Entries.ENTRY1: Foo(a=1)>')
 
+    def test_repr_and_str_with_non_data_type_mixin(self):
+        # non-data_type is a mixin that doesn't define __new__
+        class Foo:
+            def __repr__(self):
+                return 'Foo'
+            def __str__(self):
+                return 'ooF'
+        class Entries(Foo, Enum):
+            ENTRY1 = 1
+        #
+        self.assertEqual(repr(Entries.ENTRY1), 'Foo')
+        self.assertEqual(str(Entries.ENTRY1), 'ooF')
+
     def test_value_backup_assign(self):
         # check that enum will add missing values when custom __new__ does not
         class Some(Enum):