]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-111874: Call `__set_name__` on objects that define the method inside a `typing...
authorAlex Waygood <Alex.Waygood@Gmail.com>
Mon, 27 Nov 2023 16:34:44 +0000 (16:34 +0000)
committerGitHub <noreply@github.com>
Mon, 27 Nov 2023 16:34:44 +0000 (16:34 +0000)
Co-authored-by: Jelle Zijlstra <jelle.zijlstra@gmail.com>
Lib/test/test_typing.py
Lib/typing.py
Misc/NEWS.d/next/Library/2023-11-09-11-07-34.gh-issue-111874.dzYc3j.rst [new file with mode: 0644]

index 31d7fda2f978da98c7bf5030a56090c3b09e65f5..669803177315e37631155cdabbcc0363032d44f1 100644 (file)
@@ -7535,6 +7535,83 @@ class NamedTupleTests(BaseTestCase):
 
         self.assertEqual(CallNamedTuple.__orig_bases__, (NamedTuple,))
 
+    def test_setname_called_on_values_in_class_dictionary(self):
+        class Vanilla:
+            def __set_name__(self, owner, name):
+                self.name = name
+
+        class Foo(NamedTuple):
+            attr = Vanilla()
+
+        foo = Foo()
+        self.assertEqual(len(foo), 0)
+        self.assertNotIn('attr', Foo._fields)
+        self.assertIsInstance(foo.attr, Vanilla)
+        self.assertEqual(foo.attr.name, "attr")
+
+        class Bar(NamedTuple):
+            attr: Vanilla = Vanilla()
+
+        bar = Bar()
+        self.assertEqual(len(bar), 1)
+        self.assertIn('attr', Bar._fields)
+        self.assertIsInstance(bar.attr, Vanilla)
+        self.assertEqual(bar.attr.name, "attr")
+
+    def test_setname_raises_the_same_as_on_other_classes(self):
+        class CustomException(BaseException): pass
+
+        class Annoying:
+            def __set_name__(self, owner, name):
+                raise CustomException
+
+        annoying = Annoying()
+
+        with self.assertRaises(CustomException) as cm:
+            class NormalClass:
+                attr = annoying
+        normal_exception = cm.exception
+
+        with self.assertRaises(CustomException) as cm:
+            class NamedTupleClass(NamedTuple):
+                attr = annoying
+        namedtuple_exception = cm.exception
+
+        self.assertIs(type(namedtuple_exception), CustomException)
+        self.assertIs(type(namedtuple_exception), type(normal_exception))
+
+        self.assertEqual(len(namedtuple_exception.__notes__), 1)
+        self.assertEqual(
+            len(namedtuple_exception.__notes__), len(normal_exception.__notes__)
+        )
+
+        expected_note = (
+            "Error calling __set_name__ on 'Annoying' instance "
+            "'attr' in 'NamedTupleClass'"
+        )
+        self.assertEqual(namedtuple_exception.__notes__[0], expected_note)
+        self.assertEqual(
+            namedtuple_exception.__notes__[0],
+            normal_exception.__notes__[0].replace("NormalClass", "NamedTupleClass")
+        )
+
+    def test_strange_errors_when_accessing_set_name_itself(self):
+        class CustomException(Exception): pass
+
+        class Meta(type):
+            def __getattribute__(self, attr):
+                if attr == "__set_name__":
+                    raise CustomException
+                return object.__getattribute__(self, attr)
+
+        class VeryAnnoying(metaclass=Meta): pass
+
+        very_annoying = VeryAnnoying()
+
+        with self.assertRaises(CustomException):
+            class Foo(NamedTuple):
+                attr = very_annoying
+
 
 class TypedDictTests(BaseTestCase):
     def test_basics_functional_syntax(self):
index 872aca82c4e779b7730d561b6f8d408095847dc0..216f0c141b62afde7fd6b1c10d9ecc1579c99270 100644 (file)
@@ -2743,11 +2743,26 @@ class NamedTupleMeta(type):
             class_getitem = _generic_class_getitem
             nm_tpl.__class_getitem__ = classmethod(class_getitem)
         # update from user namespace without overriding special namedtuple attributes
-        for key in ns:
+        for key, val in ns.items():
             if key in _prohibited:
                 raise AttributeError("Cannot overwrite NamedTuple attribute " + key)
-            elif key not in _special and key not in nm_tpl._fields:
-                setattr(nm_tpl, key, ns[key])
+            elif key not in _special:
+                if key not in nm_tpl._fields:
+                    setattr(nm_tpl, key, val)
+                try:
+                    set_name = type(val).__set_name__
+                except AttributeError:
+                    pass
+                else:
+                    try:
+                        set_name(val, nm_tpl, key)
+                    except BaseException as e:
+                        e.add_note(
+                            f"Error calling __set_name__ on {type(val).__name__!r} "
+                            f"instance {key!r} in {typename!r}"
+                        )
+                        raise
+
         if Generic in bases:
             nm_tpl.__init_subclass__()
         return nm_tpl
diff --git a/Misc/NEWS.d/next/Library/2023-11-09-11-07-34.gh-issue-111874.dzYc3j.rst b/Misc/NEWS.d/next/Library/2023-11-09-11-07-34.gh-issue-111874.dzYc3j.rst
new file mode 100644 (file)
index 0000000..5040820
--- /dev/null
@@ -0,0 +1,4 @@
+When creating a :class:`typing.NamedTuple` class, ensure
+:func:`~object.__set_name__` is called on all objects that define
+``__set_name__`` and exist in the values of the ``NamedTuple`` class's class
+dictionary. Patch by Alex Waygood.