]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-124787: Fix `TypeAliasType` and incorrect `type_params` (#124795)
authorsobolevn <mail@sobolevn.me>
Fri, 11 Oct 2024 14:39:18 +0000 (17:39 +0300)
committerGitHub <noreply@github.com>
Fri, 11 Oct 2024 14:39:18 +0000 (17:39 +0300)
Co-authored-by: Jelle Zijlstra <jelle.zijlstra@gmail.com>
Lib/test/test_type_aliases.py
Misc/NEWS.d/next/Library/2024-09-30-20-46-32.gh-issue-124787.3FnJnP.rst [new file with mode: 0644]
Objects/typevarobject.c

index ebb65d8c6cf81b3e0baded4e5c19066947cd1f40..230bbe646baf2818043c84968a2f0de080a9547f 100644 (file)
@@ -4,7 +4,9 @@ import unittest
 from test.support import check_syntax_error, run_code
 from test.typinganndata import mod_generics_cache
 
-from typing import Callable, TypeAliasType, TypeVar, get_args
+from typing import (
+    Callable, TypeAliasType, TypeVar, TypeVarTuple, ParamSpec, get_args,
+)
 
 
 class TypeParamsInvalidTest(unittest.TestCase):
@@ -225,6 +227,46 @@ class TypeAliasConstructorTest(unittest.TestCase):
         ):
             TA[int]
 
+    def test_type_params_order_with_defaults(self):
+        HasNoDefaultT = TypeVar("HasNoDefaultT")
+        WithDefaultT = TypeVar("WithDefaultT", default=int)
+
+        HasNoDefaultP = ParamSpec("HasNoDefaultP")
+        WithDefaultP = ParamSpec("WithDefaultP", default=HasNoDefaultP)
+
+        HasNoDefaultTT = TypeVarTuple("HasNoDefaultTT")
+        WithDefaultTT = TypeVarTuple("WithDefaultTT", default=HasNoDefaultTT)
+
+        for type_params in [
+            (HasNoDefaultT, WithDefaultT),
+            (HasNoDefaultP, WithDefaultP),
+            (HasNoDefaultTT, WithDefaultTT),
+        ]:
+            with self.subTest(type_params=type_params):
+                TypeAliasType("A", int, type_params=type_params)  # ok
+
+        msg = "follows default type parameter"
+        for type_params in [
+            (WithDefaultT, HasNoDefaultT),
+            (WithDefaultP, HasNoDefaultP),
+            (WithDefaultTT, HasNoDefaultTT),
+            (WithDefaultT, HasNoDefaultP),  # different types
+        ]:
+            with self.subTest(type_params=type_params):
+                with self.assertRaisesRegex(TypeError, msg):
+                    TypeAliasType("A", int, type_params=type_params)
+
+    def test_expects_type_like(self):
+        T = TypeVar("T")
+
+        msg = "Expected a type param"
+        with self.assertRaisesRegex(TypeError, msg):
+            TypeAliasType("A", int, type_params=(1,))
+        with self.assertRaisesRegex(TypeError, msg):
+            TypeAliasType("A", int, type_params=(1, 2))
+        with self.assertRaisesRegex(TypeError, msg):
+            TypeAliasType("A", int, type_params=(T, 2))
+
     def test_keywords(self):
         TA = TypeAliasType(name="TA", value=int)
         self.assertEqual(TA.__name__, "TA")
diff --git a/Misc/NEWS.d/next/Library/2024-09-30-20-46-32.gh-issue-124787.3FnJnP.rst b/Misc/NEWS.d/next/Library/2024-09-30-20-46-32.gh-issue-124787.3FnJnP.rst
new file mode 100644 (file)
index 0000000..d9d1bbc
--- /dev/null
@@ -0,0 +1,4 @@
+Fix :class:`typing.TypeAliasType` with incorrect ``type_params`` argument.
+Now it raises a :exc:`TypeError` when a type parameter without a default
+follows one with a default, and when an entry in the ``type_params`` tuple
+is not a type parameter object.
index 51d93ed8b5ba8c85b7405bb277ecf75d0a3521e3..91cc37c9a7263654588c09eff46669872f77bf40 100644 (file)
@@ -1799,6 +1799,24 @@ _Py_make_typevartuple(PyThreadState *Py_UNUSED(ignored), PyObject *v)
     return (PyObject *)typevartuple_alloc(v, NULL, NULL);
 }
 
+static PyObject *
+get_type_param_default(PyThreadState *ts, PyObject *typeparam) {
+    // Does not modify refcount of existing objects.
+    if (Py_IS_TYPE(typeparam, ts->interp->cached_objects.typevar_type)) {
+        return typevar_default((typevarobject *)typeparam, NULL);
+    }
+    else if (Py_IS_TYPE(typeparam, ts->interp->cached_objects.paramspec_type)) {
+        return paramspec_default((paramspecobject *)typeparam, NULL);
+    }
+    else if (Py_IS_TYPE(typeparam, ts->interp->cached_objects.typevartuple_type)) {
+        return typevartuple_default((typevartupleobject *)typeparam, NULL);
+    }
+    else {
+        PyErr_Format(PyExc_TypeError, "Expected a type param, got %R", typeparam);
+        return NULL;
+    }
+}
+
 static void
 typealias_dealloc(PyObject *self)
 {
@@ -1906,25 +1924,75 @@ static PyGetSetDef typealias_getset[] = {
     {0}
 };
 
-static typealiasobject *
-typealias_alloc(PyObject *name, PyObject *type_params, PyObject *compute_value,
-                PyObject *value, PyObject *module)
-{
-    typealiasobject *ta = PyObject_GC_New(typealiasobject, &_PyTypeAlias_Type);
-    if (ta == NULL) {
+static PyObject *
+typealias_check_type_params(PyObject *type_params, int *err) {
+    // Can return type_params or NULL without exception set.
+    // Does not change the reference count of type_params,
+    // sets `*err` to 1 when error happens and sets an exception,
+    // otherwise `*err` is set to 0.
+    *err = 0;
+    if (type_params == NULL) {
         return NULL;
     }
-    ta->name = Py_NewRef(name);
+
+    assert(PyTuple_Check(type_params));
+    Py_ssize_t length = PyTuple_GET_SIZE(type_params);
+    if (!length) {  // 0-length tuples are the same as `NULL`.
+        return NULL;
+    }
+
+    PyThreadState *ts = _PyThreadState_GET();
+    int default_seen = 0;
+    for (Py_ssize_t index = 0; index < length; index++) {
+        PyObject *type_param = PyTuple_GET_ITEM(type_params, index);
+        PyObject *dflt = get_type_param_default(ts, type_param);
+        if (dflt == NULL) {
+            *err = 1;
+            return NULL;
+        }
+        if (dflt == &_Py_NoDefaultStruct) {
+            if (default_seen) {
+                *err = 1;
+                PyErr_Format(PyExc_TypeError,
+                                "non-default type parameter '%R' "
+                                "follows default type parameter",
+                                type_param);
+                return NULL;
+            }
+        } else {
+            default_seen = 1;
+            Py_DECREF(dflt);
+        }
+    }
+
+    return type_params;
+}
+
+static PyObject *
+typelias_convert_type_params(PyObject *type_params)
+{
     if (
         type_params == NULL
         || Py_IsNone(type_params)
         || (PyTuple_Check(type_params) && PyTuple_GET_SIZE(type_params) == 0)
     ) {
-        ta->type_params = NULL;
+        return NULL;
     }
     else {
-        ta->type_params = Py_NewRef(type_params);
+        return type_params;
     }
+}
+
+static typealiasobject *
+typealias_alloc(PyObject *name, PyObject *type_params, PyObject *compute_value,
+                PyObject *value, PyObject *module)
+{
+    typealiasobject *ta = PyObject_GC_New(typealiasobject, &_PyTypeAlias_Type);
+    if (ta == NULL) {
+        return NULL;
+    }
+    ta->name = Py_NewRef(name);
+    ta->type_params = Py_XNewRef(type_params);
     ta->compute_value = Py_XNewRef(compute_value);
     ta->value = Py_XNewRef(value);
     ta->module = Py_XNewRef(module);
@@ -2002,11 +2070,18 @@ typealias_new_impl(PyTypeObject *type, PyObject *name, PyObject *value,
         PyErr_SetString(PyExc_TypeError, "type_params must be a tuple");
         return NULL;
     }
+
+    int err = 0;
+    PyObject *checked_params = typealias_check_type_params(type_params, &err);
+    if (err) {
+        return NULL;
+    }
+
     PyObject *module = caller();
     if (module == NULL) {
         return NULL;
     }
-    PyObject *ta = (PyObject *)typealias_alloc(name, type_params, NULL, value,
+    PyObject *ta = (PyObject *)typealias_alloc(name, checked_params, NULL, value,
                                                module);
     Py_DECREF(module);
     return ta;
@@ -2072,7 +2147,7 @@ _Py_make_typealias(PyThreadState* unused, PyObject *args)
     assert(PyTuple_GET_SIZE(args) == 3);
     PyObject *name = PyTuple_GET_ITEM(args, 0);
     assert(PyUnicode_Check(name));
-    PyObject *type_params = PyTuple_GET_ITEM(args, 1);
+    PyObject *type_params = typelias_convert_type_params(PyTuple_GET_ITEM(args, 1));
     PyObject *compute_value = PyTuple_GET_ITEM(args, 2);
     assert(PyFunction_Check(compute_value));
     return (PyObject *)typealias_alloc(name, type_params, compute_value, NULL, NULL);