]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Add type hints to `test_background.py` (#2473)
authorScirlat Danut <danut.scirlat@gmail.com>
Sun, 4 Feb 2024 16:48:26 +0000 (18:48 +0200)
committerGitHub <noreply@github.com>
Sun, 4 Feb 2024 16:48:26 +0000 (16:48 +0000)
Co-authored-by: Scirlat Danut <scirlatdanut@scirlats-mini.lan>
tests/test_background.py

index 39a126dc9d59393369326be2bc5e81243a986a08..846deecfd9a6a1fe760cb10053f46297f662205e 100644 (file)
@@ -5,18 +5,21 @@ import pytest
 from starlette.background import BackgroundTask, BackgroundTasks
 from starlette.responses import Response
 from starlette.testclient import TestClient
+from starlette.types import Receive, Scope, Send
 
+TestClientFactory = Callable[..., TestClient]
 
-def test_async_task(test_client_factory):
+
+def test_async_task(test_client_factory: TestClientFactory) -> None:
     TASK_COMPLETE = False
 
-    async def async_task():
+    async def async_task() -> None:
         nonlocal TASK_COMPLETE
         TASK_COMPLETE = True
 
     task = BackgroundTask(async_task)
 
-    async def app(scope, receive, send):
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         response = Response("task initiated", media_type="text/plain", background=task)
         await response(scope, receive, send)
 
@@ -26,16 +29,16 @@ def test_async_task(test_client_factory):
     assert TASK_COMPLETE
 
 
-def test_sync_task(test_client_factory):
+def test_sync_task(test_client_factory: TestClientFactory) -> None:
     TASK_COMPLETE = False
 
-    def sync_task():
+    def sync_task() -> None:
         nonlocal TASK_COMPLETE
         TASK_COMPLETE = True
 
     task = BackgroundTask(sync_task)
 
-    async def app(scope, receive, send):
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         response = Response("task initiated", media_type="text/plain", background=task)
         await response(scope, receive, send)
 
@@ -45,14 +48,14 @@ def test_sync_task(test_client_factory):
     assert TASK_COMPLETE
 
 
-def test_multiple_tasks(test_client_factory: Callable[..., TestClient]):
+def test_multiple_tasks(test_client_factory: TestClientFactory) -> None:
     TASK_COUNTER = 0
 
-    def increment(amount):
+    def increment(amount: int) -> None:
         nonlocal TASK_COUNTER
         TASK_COUNTER += amount
 
-    async def app(scope, receive, send):
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         tasks = BackgroundTasks()
         tasks.add_task(increment, amount=1)
         tasks.add_task(increment, amount=2)
@@ -69,17 +72,17 @@ def test_multiple_tasks(test_client_factory: Callable[..., TestClient]):
 
 
 def test_multi_tasks_failure_avoids_next_execution(
-    test_client_factory: Callable[..., TestClient],
+    test_client_factory: TestClientFactory,
 ) -> None:
     TASK_COUNTER = 0
 
-    def increment():
+    def increment() -> None:
         nonlocal TASK_COUNTER
         TASK_COUNTER += 1
         if TASK_COUNTER == 1:
             raise Exception("task failed")
 
-    async def app(scope, receive, send):
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         tasks = BackgroundTasks()
         tasks.add_task(increment)
         tasks.add_task(increment)