]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-96784: Cover more typing special forms in `get_args()` (#96791)
authorNikita Sobolev <mail@sobolevn.me>
Wed, 14 Sep 2022 02:35:52 +0000 (05:35 +0300)
committerGitHub <noreply@github.com>
Wed, 14 Sep 2022 02:35:52 +0000 (19:35 -0700)
Lib/test/test_typing.py

index d0f6a185384f06bc361a1ce311bfb7b25daa6e7c..e4cf3c7d53b657248f8861dd157ea28e6f621b72 100644 (file)
@@ -4965,17 +4965,28 @@ class GetUtilitiesTestCase(TestCase):
         class C(Generic[T]): pass
         self.assertEqual(get_args(C[int]), (int,))
         self.assertEqual(get_args(C[T]), (T,))
+        self.assertEqual(get_args(typing.SupportsAbs[int]), (int,))  # Protocol
+        self.assertEqual(get_args(typing.SupportsAbs[T]), (T,))
+        self.assertEqual(get_args(Point2DGeneric[int]), (int,))  # TypedDict
+        self.assertEqual(get_args(Point2DGeneric[T]), (T,))
+        self.assertEqual(get_args(T), ())
         self.assertEqual(get_args(int), ())
+        self.assertEqual(get_args(Any), ())
+        self.assertEqual(get_args(Self), ())
+        self.assertEqual(get_args(LiteralString), ())
         self.assertEqual(get_args(ClassVar[int]), (int,))
         self.assertEqual(get_args(Union[int, str]), (int, str))
         self.assertEqual(get_args(Literal[42, 43]), (42, 43))
         self.assertEqual(get_args(Final[List[int]]), (List[int],))
+        self.assertEqual(get_args(Optional[int]), (int, type(None)))
+        self.assertEqual(get_args(Union[int, None]), (int, type(None)))
         self.assertEqual(get_args(Union[int, Tuple[T, int]][str]),
                          (int, Tuple[str, int]))
         self.assertEqual(get_args(typing.Dict[int, Tuple[T, T]][Optional[int]]),
                          (int, Tuple[Optional[int], Optional[int]]))
         self.assertEqual(get_args(Callable[[], T][int]), ([], int))
         self.assertEqual(get_args(Callable[..., int]), (..., int))
+        self.assertEqual(get_args(Callable[[int], str]), ([int], str))
         self.assertEqual(get_args(Union[int, Callable[[Tuple[T, ...]], str]]),
                          (int, Callable[[Tuple[T, ...]], str]))
         self.assertEqual(get_args(Tuple[int, ...]), (int, ...))
@@ -4992,12 +5003,25 @@ class GetUtilitiesTestCase(TestCase):
         self.assertEqual(get_args(collections.abc.Callable[[int], str]),
                          get_args(Callable[[int], str]))
         P = ParamSpec('P')
+        self.assertEqual(get_args(P), ())
+        self.assertEqual(get_args(P.args), ())
+        self.assertEqual(get_args(P.kwargs), ())
         self.assertEqual(get_args(Callable[P, int]), (P, int))
+        self.assertEqual(get_args(collections.abc.Callable[P, int]), (P, int))
         self.assertEqual(get_args(Callable[Concatenate[int, P], int]),
                          (Concatenate[int, P], int))
+        self.assertEqual(get_args(collections.abc.Callable[Concatenate[int, P], int]),
+                         (Concatenate[int, P], int))
+        self.assertEqual(get_args(Concatenate[int, str, P]), (int, str, P))
         self.assertEqual(get_args(list | str), (list, str))
         self.assertEqual(get_args(Required[int]), (int,))
         self.assertEqual(get_args(NotRequired[int]), (int,))
+        self.assertEqual(get_args(TypeAlias), ())
+        self.assertEqual(get_args(TypeGuard[int]), (int,))
+        Ts = TypeVarTuple('Ts')
+        self.assertEqual(get_args(Ts), ())
+        self.assertEqual(get_args(Unpack[Ts]), (Ts,))
+        self.assertEqual(get_args(tuple[Unpack[Ts]]), (Unpack[Ts],))
 
 
 class CollectionsAbcTests(BaseTestCase):
@@ -5762,9 +5786,12 @@ class NamedTupleTests(BaseTestCase):
         for G in X, Y:
             with self.subTest(type=G):
                 self.assertEqual(G.__parameters__, (T,))
+                self.assertEqual(G[T].__args__, (T,))
+                self.assertEqual(get_args(G[T]), (T,))
                 A = G[int]
                 self.assertIs(A.__origin__, G)
                 self.assertEqual(A.__args__, (int,))
+                self.assertEqual(get_args(A), (int,))
                 self.assertEqual(A.__parameters__, ())
 
                 a = A(3)