]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-108682: [Enum] raise TypeError if super().__new__ called in custom __new__ (GH...
authorEthan Furman <ethan@stoneleaf.us>
Thu, 31 Aug 2023 19:45:12 +0000 (12:45 -0700)
committerGitHub <noreply@github.com>
Thu, 31 Aug 2023 19:45:12 +0000 (12:45 -0700)
When overriding the `__new__` method of an enum, the underlying data type should be created directly; i.e. .

    member = object.__new__(cls)
    member = int.__new__(cls, value)
    member = str.__new__(cls, value)

Calling `super().__new__()` finds the lookup version of `Enum.__new__`, and will now raise an exception when detected.

Doc/howto/enum.rst
Lib/enum.py
Lib/test/test_enum.py
Misc/NEWS.d/next/Library/2023-08-30-20-10-28.gh-issue-108682.c2gzLQ.rst [new file with mode: 0644]

index 4312b4c8140f5c37524659735b7a3a7ea8010e5e..28749754a54dba29bdb9146e1c2bcf61235db713 100644 (file)
@@ -426,10 +426,17 @@ enumeration, with the exception of special methods (:meth:`__str__`,
 :meth:`__add__`, etc.), descriptors (methods are also descriptors), and
 variable names listed in :attr:`_ignore_`.
 
-Note:  if your enumeration defines :meth:`__new__` and/or :meth:`__init__` then
+Note:  if your enumeration defines :meth:`__new__` and/or :meth:`__init__`,
 any value(s) given to the enum member will be passed into those methods.
 See `Planet`_ for an example.
 
+.. note::
+
+    The :meth:`__new__` method, if defined, is used during creation of the Enum
+    members; it is then replaced by Enum's :meth:`__new__` which is used after
+    class creation for lookup of existing members.  See :ref:`new-vs-init` for
+    more details.
+
 
 Restricted Enum subclassing
 ---------------------------
@@ -895,6 +902,8 @@ Some rules:
    :meth:`__str__` method has been reset to their data types'
    :meth:`__str__` method.
 
+.. _new-vs-init:
+
 When to use :meth:`__new__` vs. :meth:`__init__`
 ------------------------------------------------
 
@@ -927,6 +936,11 @@ want one of them to be the value::
     >>> print(Coordinate(3))
     Coordinate.VY
 
+.. warning::
+
+    *Do not* call ``super().__new__()``, as the lookup-only ``__new__`` is the one
+    that is found; instead, use the data type directly.
+
 
 Finer Points
 ^^^^^^^^^^^^
@@ -1353,6 +1367,13 @@ to handle any extra arguments::
     members; it is then replaced by Enum's :meth:`__new__` which is used after
     class creation for lookup of existing members.
 
+.. warning::
+
+    *Do not* call ``super().__new__()``, as the lookup-only ``__new__`` is the one
+    that is found; instead, use the data type directly -- e.g.::
+
+       obj = int.__new__(cls, value)
+
 
 OrderedEnum
 ^^^^^^^^^^^
index 0c985b2c778569c07a9ea91d11eb48559fc8d97a..4b99e7bda2cca549d3a6e5c009e1912f9fc28038 100644 (file)
@@ -856,6 +856,8 @@ class EnumType(type):
                 value = first_enum._generate_next_value_(name, start, count, last_values[:])
                 last_values.append(value)
                 names.append((name, value))
+        if names is None:
+            names = ()
 
         # Here, names is either an iterable of (name, value) or a mapping.
         for item in names:
@@ -1112,6 +1114,11 @@ class Enum(metaclass=EnumType):
             for member in cls._member_map_.values():
                 if member._value_ == value:
                     return member
+        # still not found -- verify that members exist, in-case somebody got here mistakenly
+        # (such as via super when trying to override __new__)
+        if not cls._member_map_:
+            raise TypeError("%r has no members defined" % cls)
+        #
         # still not found -- try _missing_ hook
         try:
             exc = None
index 36a1ee47640849d464300a4ff78a5f429836ffa5..11a5b425efff9a9db09c9610d9fd819df03c87d1 100644 (file)
@@ -276,11 +276,82 @@ class _EnumTests:
     values = None
 
     def setUp(self):
-        class BaseEnum(self.enum_type):
+        if self.__class__.__name__[-5:] == 'Class':
+            class BaseEnum(self.enum_type):
+                @enum.property
+                def first(self):
+                    return '%s is first!' % self.name
+            class MainEnum(BaseEnum):
+                first = auto()
+                second = auto()
+                third = auto()
+                if issubclass(self.enum_type, Flag):
+                    dupe = 3
+                else:
+                    dupe = third
+            self.MainEnum = MainEnum
+            #
+            class NewStrEnum(self.enum_type):
+                def __str__(self):
+                    return self.name.upper()
+                first = auto()
+            self.NewStrEnum = NewStrEnum
+            #
+            class NewFormatEnum(self.enum_type):
+                def __format__(self, spec):
+                    return self.name.upper()
+                first = auto()
+            self.NewFormatEnum = NewFormatEnum
+            #
+            class NewStrFormatEnum(self.enum_type):
+                def __str__(self):
+                    return self.name.title()
+                def __format__(self, spec):
+                    return ''.join(reversed(self.name))
+                first = auto()
+            self.NewStrFormatEnum = NewStrFormatEnum
+            #
+            class NewBaseEnum(self.enum_type):
+                def __str__(self):
+                    return self.name.title()
+                def __format__(self, spec):
+                    return ''.join(reversed(self.name))
+            class NewSubEnum(NewBaseEnum):
+                first = auto()
+            self.NewSubEnum = NewSubEnum
+            #
+            class LazyGNV(self.enum_type):
+                def _generate_next_value_(name, start, last, values):
+                    pass
+            self.LazyGNV = LazyGNV
+            #
+            class BusyGNV(self.enum_type):
+                @staticmethod
+                def _generate_next_value_(name, start, last, values):
+                    pass
+            self.BusyGNV = BusyGNV
+            #
+            self.is_flag = False
+            self.names = ['first', 'second', 'third']
+            if issubclass(MainEnum, StrEnum):
+                self.values = self.names
+            elif MainEnum._member_type_ is str:
+                self.values = ['1', '2', '3']
+            elif issubclass(self.enum_type, Flag):
+                self.values = [1, 2, 4]
+                self.is_flag = True
+                self.dupe2 = MainEnum(5)
+            else:
+                self.values = self.values or [1, 2, 3]
+            #
+            if not getattr(self, 'source_values', False):
+                self.source_values = self.values
+        elif self.__class__.__name__[-8:] == 'Function':
             @enum.property
             def first(self):
                 return '%s is first!' % self.name
-        class MainEnum(BaseEnum):
+            BaseEnum = self.enum_type('BaseEnum', {'first':first})
+            #
             first = auto()
             second = auto()
             third = auto()
@@ -288,63 +359,60 @@ class _EnumTests:
                 dupe = 3
             else:
                 dupe = third
-        self.MainEnum = MainEnum
-        #
-        class NewStrEnum(self.enum_type):
+            self.MainEnum = MainEnum = BaseEnum('MainEnum', dict(first=first, second=second, third=third, dupe=dupe))
+            #
             def __str__(self):
                 return self.name.upper()
             first = auto()
-        self.NewStrEnum = NewStrEnum
-        #
-        class NewFormatEnum(self.enum_type):
+            self.NewStrEnum = self.enum_type('NewStrEnum', (('first',first),('__str__',__str__)))
+            #
             def __format__(self, spec):
                 return self.name.upper()
             first = auto()
-        self.NewFormatEnum = NewFormatEnum
-        #
-        class NewStrFormatEnum(self.enum_type):
+            self.NewFormatEnum = self.enum_type('NewFormatEnum', [('first',first),('__format__',__format__)])
+            #
             def __str__(self):
                 return self.name.title()
             def __format__(self, spec):
                 return ''.join(reversed(self.name))
             first = auto()
-        self.NewStrFormatEnum = NewStrFormatEnum
-        #
-        class NewBaseEnum(self.enum_type):
+            self.NewStrFormatEnum = self.enum_type('NewStrFormatEnum', dict(first=first, __format__=__format__, __str__=__str__))
+            #
             def __str__(self):
                 return self.name.title()
             def __format__(self, spec):
                 return ''.join(reversed(self.name))
-        class NewSubEnum(NewBaseEnum):
-            first = auto()
-        self.NewSubEnum = NewSubEnum
-        #
-        class LazyGNV(self.enum_type):
+            NewBaseEnum = self.enum_type('NewBaseEnum', dict(__format__=__format__, __str__=__str__))
+            class NewSubEnum(NewBaseEnum):
+                first = auto()
+            self.NewSubEnum = NewBaseEnum('NewSubEnum', 'first')
+            #
             def _generate_next_value_(name, start, last, values):
                 pass
-        self.LazyGNV = LazyGNV
-        #
-        class BusyGNV(self.enum_type):
+            self.LazyGNV = self.enum_type('LazyGNV', {'_generate_next_value_':_generate_next_value_})
+            #
             @staticmethod
             def _generate_next_value_(name, start, last, values):
                 pass
-        self.BusyGNV = BusyGNV
-        #
-        self.is_flag = False
-        self.names = ['first', 'second', 'third']
-        if issubclass(MainEnum, StrEnum):
-            self.values = self.names
-        elif MainEnum._member_type_ is str:
-            self.values = ['1', '2', '3']
-        elif issubclass(self.enum_type, Flag):
-            self.values = [1, 2, 4]
-            self.is_flag = True
-            self.dupe2 = MainEnum(5)
+            self.BusyGNV = self.enum_type('BusyGNV', {'_generate_next_value_':_generate_next_value_})
+            #
+            self.is_flag = False
+            self.names = ['first', 'second', 'third']
+            if issubclass(MainEnum, StrEnum):
+                self.values = self.names
+            elif MainEnum._member_type_ is str:
+                self.values = ['1', '2', '3']
+            elif issubclass(self.enum_type, Flag):
+                self.values = [1, 2, 4]
+                self.is_flag = True
+                self.dupe2 = MainEnum(5)
+            else:
+                self.values = self.values or [1, 2, 3]
+            #
+            if not getattr(self, 'source_values', False):
+                self.source_values = self.values
         else:
-            self.values = self.values or [1, 2, 3]
-        #
-        if not getattr(self, 'source_values', False):
-            self.source_values = self.values
+            raise ValueError('unknown enum style: %r' % self.__class__.__name__)
 
     def assertFormatIsValue(self, spec, member):
         self.assertEqual(spec.format(member), spec.format(member.value))
@@ -372,6 +440,17 @@ class _EnumTests:
         with self.assertRaises(AttributeError):
             del Season.SPRING.name
 
+    def test_bad_new_super(self):
+        with self.assertRaisesRegex(
+                TypeError,
+                'has no members defined',
+            ):
+            class BadSuper(self.enum_type):
+                def __new__(cls, value):
+                    obj = super().__new__(cls, value)
+                    return obj
+                failed = 1
+
     def test_basics(self):
         TE = self.MainEnum
         if self.is_flag:
@@ -427,7 +506,7 @@ class _EnumTests:
         MainEnum = self.MainEnum
         self.assertIn(MainEnum.first, MainEnum)
         self.assertTrue(self.values[0] in MainEnum)
-        if type(self) is not TestStrEnum:
+        if type(self) not in (TestStrEnumClass, TestStrEnumFunction):
             self.assertFalse('first' in MainEnum)
         val = MainEnum.dupe
         self.assertIn(val, MainEnum)
@@ -949,15 +1028,23 @@ class _FlagTests:
             self.assertTrue(~OpenXYZ(0), (X|Y|Z))
 
 
-class TestPlainEnum(_EnumTests, _PlainOutputTests, unittest.TestCase):
+class TestPlainEnumClass(_EnumTests, _PlainOutputTests, unittest.TestCase):
+    enum_type = Enum
+
+
+class TestPlainEnumFunction(_EnumTests, _PlainOutputTests, unittest.TestCase):
     enum_type = Enum
 
 
-class TestPlainFlag(_EnumTests, _PlainOutputTests, _FlagTests, unittest.TestCase):
+class TestPlainFlagClass(_EnumTests, _PlainOutputTests, _FlagTests, unittest.TestCase):
     enum_type = Flag
 
 
-class TestIntEnum(_EnumTests, _MinimalOutputTests, unittest.TestCase):
+class TestPlainFlagFunction(_EnumTests, _PlainOutputTests, _FlagTests, unittest.TestCase):
+    enum_type = Flag
+
+
+class TestIntEnumClass(_EnumTests, _MinimalOutputTests, unittest.TestCase):
     enum_type = IntEnum
     #
     def test_shadowed_attr(self):
@@ -969,7 +1056,17 @@ class TestIntEnum(_EnumTests, _MinimalOutputTests, unittest.TestCase):
         self.assertIs(Number.numerator.divisor, Number.divisor)
 
 
-class TestStrEnum(_EnumTests, _MinimalOutputTests, unittest.TestCase):
+class TestIntEnumFunction(_EnumTests, _MinimalOutputTests, unittest.TestCase):
+    enum_type = IntEnum
+    #
+    def test_shadowed_attr(self):
+        Number = IntEnum('Number', ('divisor', 'numerator'))
+        #
+        self.assertEqual(Number.divisor.numerator, 1)
+        self.assertIs(Number.numerator.divisor, Number.divisor)
+
+
+class TestStrEnumClass(_EnumTests, _MinimalOutputTests, unittest.TestCase):
     enum_type = StrEnum
     #
     def test_shadowed_attr(self):
@@ -982,64 +1079,141 @@ class TestStrEnum(_EnumTests, _MinimalOutputTests, unittest.TestCase):
         self.assertIs(Book.title.author, Book.author)
 
 
-class TestIntFlag(_EnumTests, _MinimalOutputTests, _FlagTests, unittest.TestCase):
+class TestStrEnumFunction(_EnumTests, _MinimalOutputTests, unittest.TestCase):
+    enum_type = StrEnum
+    #
+    def test_shadowed_attr(self):
+        Book = StrEnum('Book', ('author', 'title'))
+        #
+        self.assertEqual(Book.author.title(), 'Author')
+        self.assertEqual(Book.title.title(), 'Title')
+        self.assertIs(Book.title.author, Book.author)
+
+
+class TestIntFlagClass(_EnumTests, _MinimalOutputTests, _FlagTests, unittest.TestCase):
+    enum_type = IntFlag
+
+
+class TestIntFlagFunction(_EnumTests, _MinimalOutputTests, _FlagTests, unittest.TestCase):
     enum_type = IntFlag
 
 
-class TestMixedInt(_EnumTests, _MixedOutputTests, unittest.TestCase):
+class TestMixedIntClass(_EnumTests, _MixedOutputTests, unittest.TestCase):
     class enum_type(int, Enum): pass
 
 
-class TestMixedStr(_EnumTests, _MixedOutputTests, unittest.TestCase):
+class TestMixedIntFunction(_EnumTests, _MixedOutputTests, unittest.TestCase):
+    enum_type = Enum('enum_type', type=int)
+
+
+class TestMixedStrClass(_EnumTests, _MixedOutputTests, unittest.TestCase):
     class enum_type(str, Enum): pass
 
 
-class TestMixedIntFlag(_EnumTests, _MixedOutputTests, _FlagTests, unittest.TestCase):
+class TestMixedStrFunction(_EnumTests, _MixedOutputTests, unittest.TestCase):
+    enum_type = Enum('enum_type', type=str)
+
+
+class TestMixedIntFlagClass(_EnumTests, _MixedOutputTests, _FlagTests, unittest.TestCase):
     class enum_type(int, Flag): pass
 
 
-class TestMixedDate(_EnumTests, _MixedOutputTests, unittest.TestCase):
+class TestMixedIntFlagFunction(_EnumTests, _MixedOutputTests, _FlagTests, unittest.TestCase):
+    enum_type = Flag('enum_type', type=int)
 
+
+class TestMixedDateClass(_EnumTests, _MixedOutputTests, unittest.TestCase):
+    #
     values = [date(2021, 12, 25), date(2020, 3, 15), date(2019, 11, 27)]
     source_values = [(2021, 12, 25), (2020, 3, 15), (2019, 11, 27)]
-
+    #
     class enum_type(date, Enum):
+        @staticmethod
         def _generate_next_value_(name, start, count, last_values):
             values = [(2021, 12, 25), (2020, 3, 15), (2019, 11, 27)]
             return values[count]
 
 
-class TestMinimalDate(_EnumTests, _MinimalOutputTests, unittest.TestCase):
+class TestMixedDateFunction(_EnumTests, _MixedOutputTests, unittest.TestCase):
+    #
+    values = [date(2021, 12, 25), date(2020, 3, 15), date(2019, 11, 27)]
+    source_values = [(2021, 12, 25), (2020, 3, 15), (2019, 11, 27)]
+    #
+    # staticmethod decorator will be added by EnumType if not present
+    def _generate_next_value_(name, start, count, last_values):
+        values = [(2021, 12, 25), (2020, 3, 15), (2019, 11, 27)]
+        return values[count]
+    #
+    enum_type = Enum('enum_type', {'_generate_next_value_':_generate_next_value_}, type=date)
+
 
+class TestMinimalDateClass(_EnumTests, _MinimalOutputTests, unittest.TestCase):
+    #
     values = [date(2023, 12, 1), date(2016, 2, 29), date(2009, 1, 1)]
     source_values = [(2023, 12, 1), (2016, 2, 29), (2009, 1, 1)]
-
+    #
     class enum_type(date, ReprEnum):
+        # staticmethod decorator will be added by EnumType if absent
         def _generate_next_value_(name, start, count, last_values):
             values = [(2023, 12, 1), (2016, 2, 29), (2009, 1, 1)]
             return values[count]
 
 
-class TestMixedFloat(_EnumTests, _MixedOutputTests, unittest.TestCase):
+class TestMinimalDateFunction(_EnumTests, _MinimalOutputTests, unittest.TestCase):
+    #
+    values = [date(2023, 12, 1), date(2016, 2, 29), date(2009, 1, 1)]
+    source_values = [(2023, 12, 1), (2016, 2, 29), (2009, 1, 1)]
+    #
+    @staticmethod
+    def _generate_next_value_(name, start, count, last_values):
+        values = [(2023, 12, 1), (2016, 2, 29), (2009, 1, 1)]
+        return values[count]
+    #
+    enum_type = ReprEnum('enum_type', {'_generate_next_value_':_generate_next_value_}, type=date)
 
-    values = [1.1, 2.2, 3.3]
 
+class TestMixedFloatClass(_EnumTests, _MixedOutputTests, unittest.TestCase):
+    #
+    values = [1.1, 2.2, 3.3]
+    #
     class enum_type(float, Enum):
         def _generate_next_value_(name, start, count, last_values):
             values = [1.1, 2.2, 3.3]
             return values[count]
 
 
-class TestMinimalFloat(_EnumTests, _MinimalOutputTests, unittest.TestCase):
+class TestMixedFloatFunction(_EnumTests, _MixedOutputTests, unittest.TestCase):
+    #
+    values = [1.1, 2.2, 3.3]
+    #
+    def _generate_next_value_(name, start, count, last_values):
+        values = [1.1, 2.2, 3.3]
+        return values[count]
+    #
+    enum_type = Enum('enum_type', {'_generate_next_value_':_generate_next_value_}, type=float)
 
-    values = [4.4, 5.5, 6.6]
 
+class TestMinimalFloatClass(_EnumTests, _MinimalOutputTests, unittest.TestCase):
+    #
+    values = [4.4, 5.5, 6.6]
+    #
     class enum_type(float, ReprEnum):
         def _generate_next_value_(name, start, count, last_values):
             values = [4.4, 5.5, 6.6]
             return values[count]
 
 
+class TestMinimalFloatFunction(_EnumTests, _MinimalOutputTests, unittest.TestCase):
+    #
+    values = [4.4, 5.5, 6.6]
+    #
+    def _generate_next_value_(name, start, count, last_values):
+        values = [4.4, 5.5, 6.6]
+        return values[count]
+    #
+    enum_type = ReprEnum('enum_type', {'_generate_next_value_':_generate_next_value_}, type=float)
+
+
 class TestSpecial(unittest.TestCase):
     """
     various operations that are not attributable to every possible enum
diff --git a/Misc/NEWS.d/next/Library/2023-08-30-20-10-28.gh-issue-108682.c2gzLQ.rst b/Misc/NEWS.d/next/Library/2023-08-30-20-10-28.gh-issue-108682.c2gzLQ.rst
new file mode 100644 (file)
index 0000000..148d432
--- /dev/null
@@ -0,0 +1,2 @@
+Enum: raise :exc:`TypeError` if ``super().__new__()`` is called from a
+custom ``__new__``.