]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
bpo-44490: Add __parameters__ and __getitem__ to types.Union (GH-26980)
authorYurii Karabas <1998uriyyo@gmail.com>
Tue, 6 Jul 2021 18:04:33 +0000 (21:04 +0300)
committerGitHub <noreply@github.com>
Tue, 6 Jul 2021 18:04:33 +0000 (11:04 -0700)
Co-authored-by: Ken Jin <28750310+Fidget-Spinner@users.noreply.github.com>
Co-authored-by: Guido van Rossum <gvanrossum@gmail.com>
Include/genericaliasobject.h
Lib/test/test_types.py
Misc/NEWS.d/next/Core and Builtins/2021-07-01-11-59-34.bpo-44490.xY80VR.rst [new file with mode: 0644]
Objects/genericaliasobject.c
Objects/unionobject.c

index cf002976b27cd7712ee80bf348f55c5c562bc12d..4ce9244bb4ce79ded2668e48096e1510fe7faf8d 100644 (file)
@@ -5,6 +5,11 @@
 extern "C" {
 #endif
 
+#ifndef Py_LIMITED_API
+PyAPI_FUNC(PyObject *) _Py_subs_parameters(PyObject *, PyObject *, PyObject *, PyObject *);
+PyAPI_FUNC(PyObject *) _Py_make_parameters(PyObject *);
+#endif
+
 PyAPI_FUNC(PyObject *) Py_GenericAlias(PyObject *, PyObject *);
 PyAPI_DATA(PyTypeObject) Py_GenericAliasType;
 
index ae7b17bd590e61b9a2d14a676420efef2867d427..7f7ce86ff08ef3722ebbec8f12ca4fd5d02b89cd 100644 (file)
@@ -666,6 +666,16 @@ class TypesTests(unittest.TestCase):
         assert TV | str == typing.Union[TV, str]
         assert str | TV == typing.Union[str, TV]
 
+    def test_union_parameter_chaining(self):
+        T = typing.TypeVar("T")
+        S = typing.TypeVar("S")
+
+        self.assertEqual((float | list[T])[int], float | list[int])
+        self.assertEqual(list[int | list[T]].__parameters__, (T,))
+        self.assertEqual(list[int | list[T]][str], list[int | list[str]])
+        self.assertEqual((list[T] | list[S]).__parameters__, (T, S))
+        self.assertEqual((list[T] | list[S])[int, T], list[int] | list[T])
+
     def test_or_type_operator_with_forward(self):
         T = typing.TypeVar('T')
         ForwardAfter = T | 'Forward'
diff --git a/Misc/NEWS.d/next/Core and Builtins/2021-07-01-11-59-34.bpo-44490.xY80VR.rst b/Misc/NEWS.d/next/Core and Builtins/2021-07-01-11-59-34.bpo-44490.xY80VR.rst
new file mode 100644 (file)
index 0000000..4912bca
--- /dev/null
@@ -0,0 +1,2 @@
+Add ``__parameters__`` attribute and ``__getitem__``
+operator to ``types.Union``. Patch provided by Yurii Karabas.
index 803912b7a18a49e52cf1d69e6a056175255139dd..d3d387193d3572e93ea65ebdb30cbd30b77cd8b6 100644 (file)
@@ -198,8 +198,8 @@ tuple_add(PyObject *self, Py_ssize_t len, PyObject *item)
     return 0;
 }
 
-static PyObject *
-make_parameters(PyObject *args)
+PyObject *
+_Py_make_parameters(PyObject *args)
 {
     Py_ssize_t nargs = PyTuple_GET_SIZE(args);
     Py_ssize_t len = nargs;
@@ -294,18 +294,10 @@ subs_tvars(PyObject *obj, PyObject *params, PyObject **argitems)
     return obj;
 }
 
-static PyObject *
-ga_getitem(PyObject *self, PyObject *item)
+PyObject *
+_Py_subs_parameters(PyObject *self, PyObject *args, PyObject *parameters, PyObject *item)
 {
-    gaobject *alias = (gaobject *)self;
-    // do a lookup for __parameters__ so it gets populated (if not already)
-    if (alias->parameters == NULL) {
-        alias->parameters = make_parameters(alias->args);
-        if (alias->parameters == NULL) {
-            return NULL;
-        }
-    }
-    Py_ssize_t nparams = PyTuple_GET_SIZE(alias->parameters);
+    Py_ssize_t nparams = PyTuple_GET_SIZE(parameters);
     if (nparams == 0) {
         return PyErr_Format(PyExc_TypeError,
                             "There are no type variables left in %R",
@@ -320,32 +312,32 @@ ga_getitem(PyObject *self, PyObject *item)
                             nitems > nparams ? "many" : "few",
                             self);
     }
-    /* Replace all type variables (specified by alias->parameters)
+    /* Replace all type variables (specified by parameters)
        with corresponding values specified by argitems.
         t = list[T];          t[int]      -> newargs = [int]
         t = dict[str, T];     t[int]      -> newargs = [str, int]
         t = dict[T, list[S]]; t[str, int] -> newargs = [str, list[int]]
      */
-    Py_ssize_t nargs = PyTuple_GET_SIZE(alias->args);
+    Py_ssize_t nargs = PyTuple_GET_SIZE(args);
     PyObject *newargs = PyTuple_New(nargs);
     if (newargs == NULL) {
         return NULL;
     }
     for (Py_ssize_t iarg = 0; iarg < nargs; iarg++) {
-        PyObject *arg = PyTuple_GET_ITEM(alias->args, iarg);
+        PyObject *arg = PyTuple_GET_ITEM(args, iarg);
         int typevar = is_typevar(arg);
         if (typevar < 0) {
             Py_DECREF(newargs);
             return NULL;
         }
         if (typevar) {
-            Py_ssize_t iparam = tuple_index(alias->parameters, nparams, arg);
+            Py_ssize_t iparam = tuple_index(parameters, nparams, arg);
             assert(iparam >= 0);
             arg = argitems[iparam];
             Py_INCREF(arg);
         }
         else {
-            arg = subs_tvars(arg, alias->parameters, argitems);
+            arg = subs_tvars(arg, parameters, argitems);
             if (arg == NULL) {
                 Py_DECREF(newargs);
                 return NULL;
@@ -354,6 +346,26 @@ ga_getitem(PyObject *self, PyObject *item)
         PyTuple_SET_ITEM(newargs, iarg, arg);
     }
 
+    return newargs;
+}
+
+static PyObject *
+ga_getitem(PyObject *self, PyObject *item)
+{
+    gaobject *alias = (gaobject *)self;
+    // Populate __parameters__ if needed.
+    if (alias->parameters == NULL) {
+        alias->parameters = _Py_make_parameters(alias->args);
+        if (alias->parameters == NULL) {
+            return NULL;
+        }
+    }
+
+    PyObject *newargs = _Py_subs_parameters(self, alias->args, alias->parameters, item);
+    if (newargs == NULL) {
+        return NULL;
+    }
+
     PyObject *res = Py_GenericAlias(alias->origin, newargs);
 
     Py_DECREF(newargs);
@@ -550,7 +562,7 @@ ga_parameters(PyObject *self, void *unused)
 {
     gaobject *alias = (gaobject *)self;
     if (alias->parameters == NULL) {
-        alias->parameters = make_parameters(alias->args);
+        alias->parameters = _Py_make_parameters(alias->args);
         if (alias->parameters == NULL) {
             return NULL;
         }
index 8435763b5ea7ca71f1b447739188e9c71e496cd0..d2a10dfec858ea4a061b08abef4346fb50082c0d 100644 (file)
@@ -8,6 +8,7 @@
 typedef struct {
     PyObject_HEAD
     PyObject *args;
+    PyObject *parameters;
 } unionobject;
 
 static void
@@ -18,6 +19,7 @@ unionobject_dealloc(PyObject *self)
     _PyObject_GC_UNTRACK(self);
 
     Py_XDECREF(alias->args);
+    Py_XDECREF(alias->parameters);
     Py_TYPE(self)->tp_free(self);
 }
 
@@ -26,6 +28,7 @@ union_traverse(PyObject *self, visitproc visit, void *arg)
 {
     unionobject *alias = (unionobject *)self;
     Py_VISIT(alias->args);
+    Py_VISIT(alias->parameters);
     return 0;
 }
 
@@ -435,6 +438,53 @@ static PyMethodDef union_methods[] = {
         {"__subclasscheck__", union_subclasscheck, METH_O},
         {0}};
 
+
+static PyObject *
+union_getitem(PyObject *self, PyObject *item)
+{
+    unionobject *alias = (unionobject *)self;
+    // Populate __parameters__ if needed.
+    if (alias->parameters == NULL) {
+        alias->parameters = _Py_make_parameters(alias->args);
+        if (alias->parameters == NULL) {
+            return NULL;
+        }
+    }
+
+    PyObject *newargs = _Py_subs_parameters(self, alias->args, alias->parameters, item);
+    if (newargs == NULL) {
+        return NULL;
+    }
+
+    PyObject *res = _Py_Union(newargs);
+
+    Py_DECREF(newargs);
+    return res;
+}
+
+static PyMappingMethods union_as_mapping = {
+    .mp_subscript = union_getitem,
+};
+
+static PyObject *
+union_parameters(PyObject *self, void *Py_UNUSED(unused))
+{
+    unionobject *alias = (unionobject *)self;
+    if (alias->parameters == NULL) {
+        alias->parameters = _Py_make_parameters(alias->args);
+        if (alias->parameters == NULL) {
+            return NULL;
+        }
+    }
+    Py_INCREF(alias->parameters);
+    return alias->parameters;
+}
+
+static PyGetSetDef union_properties[] = {
+    {"__parameters__", union_parameters, (setter)NULL, "Type variables in the types.Union.", NULL},
+    {0}
+};
+
 static PyNumberMethods union_as_number = {
         .nb_or = _Py_union_type_or, // Add __or__ function
 };
@@ -456,8 +506,10 @@ PyTypeObject _Py_UnionType = {
     .tp_members = union_members,
     .tp_methods = union_methods,
     .tp_richcompare = union_richcompare,
+    .tp_as_mapping = &union_as_mapping,
     .tp_as_number = &union_as_number,
     .tp_repr = union_repr,
+    .tp_getset = union_properties,
 };
 
 PyObject *
@@ -489,6 +541,7 @@ _Py_Union(PyObject *args)
         return NULL;
     }
 
+    result->parameters = NULL;
     result->args = dedup_and_flatten_args(args);
     _PyObject_GC_TRACK(result);
     if (result->args == NULL) {