]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-104600: Make function.__type_params__ writable (#104601)
authorJelle Zijlstra <jelle.zijlstra@gmail.com>
Thu, 18 May 2023 23:45:37 +0000 (16:45 -0700)
committerGitHub <noreply@github.com>
Thu, 18 May 2023 23:45:37 +0000 (16:45 -0700)
Lib/functools.py
Lib/test/test_funcattrs.py
Lib/test/test_functools.py
Lib/test/test_type_params.py
Misc/NEWS.d/next/Library/2023-05-17-21-01-48.gh-issue-104600.E6CK35.rst [new file with mode: 0644]
Objects/funcobject.c

index aaf4291150fbbf2bd388f931d9c41d22c35b51f1..72b2103e7a5544c8599038a65bbf4d943bb2b0b0 100644 (file)
@@ -30,7 +30,7 @@ from types import GenericAlias
 # wrapper functions that can handle naive introspection
 
 WRAPPER_ASSIGNMENTS = ('__module__', '__name__', '__qualname__', '__doc__',
-                       '__annotations__')
+                       '__annotations__', '__type_params__')
 WRAPPER_UPDATES = ('__dict__',)
 def update_wrapper(wrapper,
                    wrapped,
index 77977d0ae966f8a1682b6fd8110bf0a6b495831e..e08d72877d8aef334a029953fd1f93ddfc5812bf 100644 (file)
@@ -1,5 +1,6 @@
 import textwrap
 import types
+import typing
 import unittest
 
 
@@ -190,6 +191,20 @@ class FunctionPropertiesTest(FuncAttrsTest):
         # __qualname__ must be a string
         self.cannot_set_attr(self.b, '__qualname__', 7, TypeError)
 
+    def test___type_params__(self):
+        def generic[T](): pass
+        def not_generic(): pass
+        T, = generic.__type_params__
+        self.assertIsInstance(T, typing.TypeVar)
+        self.assertEqual(generic.__type_params__, (T,))
+        self.assertEqual(not_generic.__type_params__, ())
+        with self.assertRaises(TypeError):
+            del not_generic.__type_params__
+        with self.assertRaises(TypeError):
+            not_generic.__type_params__ = 42
+        not_generic.__type_params__ = (T,)
+        self.assertEqual(not_generic.__type_params__, (T,))
+
     def test___code__(self):
         num_one, num_two = 7, 8
         def a(): pass
index af286052a7d5602c1c870375647800dd8973b7f3..d668fa4c3adf5c2540040ac299b0639324cde6c2 100644 (file)
@@ -617,7 +617,7 @@ class TestUpdateWrapper(unittest.TestCase):
 
 
     def _default_update(self):
-        def f(a:'This is a new annotation'):
+        def f[T](a:'This is a new annotation'):
             """This is a test"""
             pass
         f.attr = 'This is also a test'
@@ -630,12 +630,14 @@ class TestUpdateWrapper(unittest.TestCase):
     def test_default_update(self):
         wrapper, f = self._default_update()
         self.check_wrapper(wrapper, f)
+        T, = f.__type_params__
         self.assertIs(wrapper.__wrapped__, f)
         self.assertEqual(wrapper.__name__, 'f')
         self.assertEqual(wrapper.__qualname__, f.__qualname__)
         self.assertEqual(wrapper.attr, 'This is also a test')
         self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
         self.assertNotIn('b', wrapper.__annotations__)
+        self.assertEqual(wrapper.__type_params__, (T,))
 
     @unittest.skipIf(sys.flags.optimize >= 2,
                      "Docstrings are omitted with -O2 and above")
index 96bd1fa0bab99080f7f23e514426abb90cde92bb..466e3bd43a68e53189708d8ce6bc4d313b48f7e2 100644 (file)
@@ -843,5 +843,5 @@ class TypeParamsTypeParamsDunder(unittest.TestCase):
             func.__type_params__ = ()
         """
 
-        with self.assertRaisesRegex(AttributeError, "attribute '__type_params__' of 'function' objects is not writable"):
-            run_code(code)
+        ns = run_code(code)
+        self.assertEqual(ns["func"].__type_params__, ())
diff --git a/Misc/NEWS.d/next/Library/2023-05-17-21-01-48.gh-issue-104600.E6CK35.rst b/Misc/NEWS.d/next/Library/2023-05-17-21-01-48.gh-issue-104600.E6CK35.rst
new file mode 100644 (file)
index 0000000..64f81e1
--- /dev/null
@@ -0,0 +1,2 @@
+:func:`functools.update_wrapper` now sets the ``__type_params__`` attribute
+(added by :pep:`695`).
index 69898bf722d61f4adbdcb81697e8e49108ba7729..753038600aa858d55294dcb986130809c60323d8 100644 (file)
@@ -665,6 +665,20 @@ func_get_type_params(PyFunctionObject *op, void *Py_UNUSED(ignored))
     return Py_NewRef(op->func_typeparams);
 }
 
+static int
+func_set_type_params(PyFunctionObject *op, PyObject *value, void *Py_UNUSED(ignored))
+{
+    /* Not legal to del f.__type_params__ or to set it to anything
+     * other than a tuple object. */
+    if (value == NULL || !PyTuple_Check(value)) {
+        PyErr_SetString(PyExc_TypeError,
+                        "__type_params__ must be set to a tuple");
+        return -1;
+    }
+    Py_XSETREF(op->func_typeparams, Py_NewRef(value));
+    return 0;
+}
+
 PyObject *
 _Py_set_function_type_params(PyThreadState *Py_UNUSED(ignored), PyObject *func,
                              PyObject *type_params)
@@ -687,7 +701,8 @@ static PyGetSetDef func_getsetlist[] = {
     {"__dict__", PyObject_GenericGetDict, PyObject_GenericSetDict},
     {"__name__", (getter)func_get_name, (setter)func_set_name},
     {"__qualname__", (getter)func_get_qualname, (setter)func_set_qualname},
-    {"__type_params__", (getter)func_get_type_params, NULL},
+    {"__type_params__", (getter)func_get_type_params,
+     (setter)func_set_type_params},
     {NULL} /* Sentinel */
 };