]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
bpo-43224: Implement substitution of unpacked TypeVarTuple (GH-31800)
authorSerhiy Storchaka <storchaka@gmail.com>
Fri, 11 Mar 2022 19:43:58 +0000 (21:43 +0200)
committerGitHub <noreply@github.com>
Fri, 11 Mar 2022 19:43:58 +0000 (21:43 +0200)
Lib/test/test_typing.py
Lib/typing.py

index 91b2e77e97b5a331b4208a7bb7513fa563b1ec1b..a6936653bc566fbf3b47b8efbae1db34ab1fb1f7 100644 (file)
@@ -411,6 +411,10 @@ class UnpackTests(BaseTestCase):
 
 class TypeVarTupleTests(BaseTestCase):
 
+    def assertEndsWith(self, string, tail):
+        if not string.endswith(tail):
+            self.fail(f"String {string!r} does not end with {tail!r}")
+
     def test_instance_is_equal_to_itself(self):
         Ts = TypeVarTuple('Ts')
         self.assertEqual(Ts, Ts)
@@ -457,6 +461,56 @@ class TypeVarTupleTests(BaseTestCase):
         self.assertEqual(t2.__args__, (Unpack[Ts],))
         self.assertEqual(t2.__parameters__, (Ts,))
 
+    def test_var_substitution(self):
+        Ts = TypeVarTuple('Ts')
+        T = TypeVar('T')
+        T2 = TypeVar('T2')
+        class G(Generic[Unpack[Ts]]): pass
+
+        for A in G, Tuple:
+            B = A[Unpack[Ts]]
+            if A != Tuple:
+                self.assertEqual(B[()], A[()])
+            self.assertEqual(B[float], A[float])
+            self.assertEqual(B[float, str], A[float, str])
+
+            C = List[A[Unpack[Ts]]]
+            if A != Tuple:
+                self.assertEqual(C[()], List[A[()]])
+            self.assertEqual(C[float], List[A[float]])
+            self.assertEqual(C[float, str], List[A[float, str]])
+
+            D = A[T, Unpack[Ts], T2]
+            with self.assertRaises(TypeError):
+                D[()]
+            with self.assertRaises(TypeError):
+                D[float]
+            self.assertEqual(D[float, str], A[float, str])
+            self.assertEqual(D[float, str, int], A[float, str, int])
+            self.assertEqual(D[float, str, int, bytes], A[float, str, int, bytes])
+
+            E = Tuple[List[T], A[Unpack[Ts]], List[T2]]
+            with self.assertRaises(TypeError):
+                E[()]
+            with self.assertRaises(TypeError):
+                E[float]
+            if A != Tuple:
+                self.assertEqual(E[float, str],
+                                 Tuple[List[float], A[()], List[str]])
+            self.assertEqual(E[float, str, int],
+                             Tuple[List[float], A[str], List[int]])
+            self.assertEqual(E[float, str, int, bytes],
+                             Tuple[List[float], A[str, int], List[bytes]])
+
+    def test_repr_is_correct(self):
+        Ts = TypeVarTuple('Ts')
+        self.assertEqual(repr(Ts), 'Ts')
+        self.assertEqual(repr(Unpack[Ts]), '*Ts')
+        self.assertEqual(repr(tuple[Unpack[Ts]]), 'tuple[*Ts]')
+        self.assertEqual(repr(Tuple[Unpack[Ts]]), 'typing.Tuple[*Ts]')
+        self.assertEqual(repr(Unpack[tuple[Unpack[Ts]]]), '*tuple[*Ts]')
+        self.assertEqual(repr(Unpack[Tuple[Unpack[Ts]]]), '*typing.Tuple[*Ts]')
+
     def test_repr_is_correct(self):
         Ts = TypeVarTuple('Ts')
         self.assertEqual(repr(Ts), 'Ts')
@@ -470,78 +524,51 @@ class TypeVarTupleTests(BaseTestCase):
         Ts = TypeVarTuple('Ts')
         class A(Generic[Unpack[Ts]]): pass
 
-        self.assertTrue(repr(A[()]).endswith('A[()]'))
-        self.assertTrue(repr(A[float]).endswith('A[float]'))
-        self.assertTrue(repr(A[float, str]).endswith('A[float, str]'))
-        self.assertTrue(repr(
-            A[Unpack[tuple[int, ...]]]
-        ).endswith(
-            'A[*tuple[int, ...]]'
-        ))
-        self.assertTrue(repr(
-            A[float, Unpack[tuple[int, ...]]]
-        ).endswith(
-            'A[float, *tuple[int, ...]]'
-        ))
-        self.assertTrue(repr(
-            A[Unpack[tuple[int, ...]], str]
-        ).endswith(
-            'A[*tuple[int, ...], str]'
-        ))
-        self.assertTrue(repr(
-            A[float, Unpack[tuple[int, ...]], str]
-        ).endswith(
-            'A[float, *tuple[int, ...], str]'
-        ))
+        self.assertEndsWith(repr(A[()]), 'A[()]')
+        self.assertEndsWith(repr(A[float]), 'A[float]')
+        self.assertEndsWith(repr(A[float, str]), 'A[float, str]')
+        self.assertEndsWith(repr(A[Unpack[tuple[int, ...]]]),
+                            'A[*tuple[int, ...]]')
+        self.assertEndsWith(repr(A[float, Unpack[tuple[int, ...]]]),
+                            'A[float, *tuple[int, ...]]')
+        self.assertEndsWith(repr(A[Unpack[tuple[int, ...]], str]),
+                            'A[*tuple[int, ...], str]')
+        self.assertEndsWith(repr(A[float, Unpack[tuple[int, ...]], str]),
+                            'A[float, *tuple[int, ...], str]')
 
     def test_variadic_class_alias_repr_is_correct(self):
         Ts = TypeVarTuple('Ts')
         class A(Generic[Unpack[Ts]]): pass
 
         B = A[Unpack[Ts]]
-        self.assertTrue(repr(B).endswith('A[*Ts]'))
-        with self.assertRaises(NotImplementedError):
-            B[()]
-        with self.assertRaises(NotImplementedError):
-            B[float]
-        with self.assertRaises(NotImplementedError):
-            B[float, str]
+        self.assertEndsWith(repr(B), 'A[*Ts]')
+        self.assertEndsWith(repr(B[()]), 'A[()]')
+        self.assertEndsWith(repr(B[float]), 'A[float]')
+        self.assertEndsWith(repr(B[float, str]), 'A[float, str]')
 
         C = A[Unpack[Ts], int]
-        self.assertTrue(repr(C).endswith('A[*Ts, int]'))
-        with self.assertRaises(NotImplementedError):
-            C[()]
-        with self.assertRaises(NotImplementedError):
-            C[float]
-        with self.assertRaises(NotImplementedError):
-            C[float, str]
+        self.assertEndsWith(repr(C), 'A[*Ts, int]')
+        self.assertEndsWith(repr(C[()]), 'A[int]')
+        self.assertEndsWith(repr(C[float]), 'A[float, int]')
+        self.assertEndsWith(repr(C[float, str]), 'A[float, str, int]')
 
         D = A[int, Unpack[Ts]]
-        self.assertTrue(repr(D).endswith('A[int, *Ts]'))
-        with self.assertRaises(NotImplementedError):
-            D[()]
-        with self.assertRaises(NotImplementedError):
-            D[float]
-        with self.assertRaises(NotImplementedError):
-            D[float, str]
+        self.assertEndsWith(repr(D), 'A[int, *Ts]')
+        self.assertEndsWith(repr(D[()]), 'A[int]')
+        self.assertEndsWith(repr(D[float]), 'A[int, float]')
+        self.assertEndsWith(repr(D[float, str]), 'A[int, float, str]')
 
         E = A[int, Unpack[Ts], str]
-        self.assertTrue(repr(E).endswith('A[int, *Ts, str]'))
-        with self.assertRaises(NotImplementedError):
-            E[()]
-        with self.assertRaises(NotImplementedError):
-            E[float]
-        with self.assertRaises(NotImplementedError):
-            E[float, bool]
+        self.assertEndsWith(repr(E), 'A[int, *Ts, str]')
+        self.assertEndsWith(repr(E[()]), 'A[int, str]')
+        self.assertEndsWith(repr(E[float]), 'A[int, float, str]')
+        self.assertEndsWith(repr(E[float, str]), 'A[int, float, str, str]')
 
         F = A[Unpack[Ts], Unpack[tuple[str, ...]]]
-        self.assertTrue(repr(F).endswith('A[*Ts, *tuple[str, ...]]'))
-        with self.assertRaises(NotImplementedError):
-            F[()]
-        with self.assertRaises(NotImplementedError):
-            F[float]
-        with self.assertRaises(NotImplementedError):
-            F[float, int]
+        self.assertEndsWith(repr(F), 'A[*Ts, *tuple[str, ...]]')
+        self.assertEndsWith(repr(F[()]), 'A[*tuple[str, ...]]')
+        self.assertEndsWith(repr(F[float]), 'A[float, *tuple[str, ...]]')
+        self.assertEndsWith(repr(F[float, str]), 'A[float, str, *tuple[str, ...]]')
 
     def test_cannot_subclass_class(self):
         with self.assertRaises(TypeError):
index 062c01ef2a9b9eeb15e50d129726ef7f7f37d34b..842554f193ca791e899fe31c92243172208b818d 100644 (file)
@@ -1297,30 +1297,39 @@ class _GenericAlias(_BaseGenericAlias, _root=True):
         # anything more exotic than a plain `TypeVar`, we need to consider
         # edge cases.
 
-        if any(isinstance(p, TypeVarTuple) for p in self.__parameters__):
-            raise NotImplementedError(
-                "Type substitution for TypeVarTuples is not yet implemented"
-            )
+        params = self.__parameters__
         # In the example above, this would be {T3: str}
-        new_arg_by_param = dict(zip(self.__parameters__, args))
+        new_arg_by_param = {}
+        for i, param in enumerate(params):
+            if isinstance(param, TypeVarTuple):
+                j = len(args) - (len(params) - i - 1)
+                if j < i:
+                    raise TypeError(f"Too few arguments for {self}")
+                new_arg_by_param.update(zip(params[:i], args[:i]))
+                new_arg_by_param[param] = args[i: j]
+                new_arg_by_param.update(zip(params[i + 1:], args[j:]))
+                break
+        else:
+            new_arg_by_param.update(zip(params, args))
 
         new_args = []
         for old_arg in self.__args__:
 
-            if _is_unpacked_typevartuple(old_arg):
-                original_typevartuple = old_arg.__parameters__[0]
-                new_arg = new_arg_by_param[original_typevartuple]
+            substfunc = getattr(old_arg, '__typing_subst__', None)
+            if substfunc:
+                new_arg = substfunc(new_arg_by_param[old_arg])
             else:
-                substfunc = getattr(old_arg, '__typing_subst__', None)
-                if substfunc:
-                    new_arg = substfunc(new_arg_by_param[old_arg])
+                subparams = getattr(old_arg, '__parameters__', ())
+                if not subparams:
+                    new_arg = old_arg
                 else:
-                    subparams = getattr(old_arg, '__parameters__', ())
-                    if not subparams:
-                        new_arg = old_arg
-                    else:
-                        subargs = tuple(new_arg_by_param[x] for x in subparams)
-                        new_arg = old_arg[subargs]
+                    subargs = []
+                    for x in subparams:
+                        if isinstance(x, TypeVarTuple):
+                            subargs.extend(new_arg_by_param[x])
+                        else:
+                            subargs.append(new_arg_by_param[x])
+                    new_arg = old_arg[tuple(subargs)]
 
             if self.__origin__ == collections.abc.Callable and isinstance(new_arg, tuple):
                 # Consider the following `Callable`.
@@ -1612,6 +1621,12 @@ class _UnpackGenericAlias(_GenericAlias, _root=True):
         # a single item.
         return '*' + repr(self.__args__[0])
 
+    def __getitem__(self, args):
+        if (len(self.__parameters__) == 1 and
+                isinstance(self.__parameters__[0], TypeVarTuple)):
+            return args
+        return super().__getitem__(args)
+
 
 class Generic:
     """Abstract base class for generic types.