]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Add type hints to `test_concurency.py` (#2474)
authorScirlat Danut <danut.scirlat@gmail.com>
Sun, 4 Feb 2024 17:15:40 +0000 (19:15 +0200)
committerGitHub <noreply@github.com>
Sun, 4 Feb 2024 17:15:40 +0000 (17:15 +0000)
Co-authored-by: Scirlat Danut <scirlatdanut@scirlats-mini.lan>
tests/test_concurrency.py

index 61fe5ff7b27efeb6a4183af9ec6bf76a3550d57d..aba3ceb1ae0b49929b8fd616a22829583b7548f5 100644 (file)
@@ -1,4 +1,5 @@
 from contextvars import ContextVar
+from typing import Callable, Iterator
 
 import anyio
 import pytest
@@ -8,17 +9,20 @@ from starlette.concurrency import iterate_in_threadpool, run_until_first_complet
 from starlette.requests import Request
 from starlette.responses import Response
 from starlette.routing import Route
+from starlette.testclient import TestClient
+
+TestClientFactory = Callable[..., TestClient]
 
 
 @pytest.mark.anyio
-async def test_run_until_first_complete():
+async def test_run_until_first_complete() -> None:
     task1_finished = anyio.Event()
     task2_finished = anyio.Event()
 
-    async def task1():
+    async def task1() -> None:
         task1_finished.set()
 
-    async def task2():
+    async def task2() -> None:
         await task1_finished.wait()
         await anyio.sleep(0)  # pragma: nocover
         task2_finished.set()  # pragma: nocover
@@ -28,7 +32,9 @@ async def test_run_until_first_complete():
     assert not task2_finished.is_set()
 
 
-def test_accessing_context_from_threaded_sync_endpoint(test_client_factory) -> None:
+def test_accessing_context_from_threaded_sync_endpoint(
+    test_client_factory: TestClientFactory,
+) -> None:
     ctxvar: ContextVar[bytes] = ContextVar("ctxvar")
     ctxvar.set(b"data")
 
@@ -45,7 +51,7 @@ def test_accessing_context_from_threaded_sync_endpoint(test_client_factory) -> N
 @pytest.mark.anyio
 async def test_iterate_in_threadpool() -> None:
     class CustomIterable:
-        def __iter__(self):
+        def __iter__(self) -> Iterator[int]:
             yield from range(3)
 
     assert [v async for v in iterate_in_threadpool(CustomIterable())] == [0, 1, 2]