From 48ae5ccc0ef0853d3887eeecc4d3a6b352df9474 Mon Sep 17 00:00:00 2001 From: Scirlat Danut Date: Sun, 4 Feb 2024 18:48:26 +0200 Subject: [PATCH] Add type hints to `test_background.py` (#2473) Co-authored-by: Scirlat Danut --- tests/test_background.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/tests/test_background.py b/tests/test_background.py index 39a126dc..846deecf 100644 --- a/tests/test_background.py +++ b/tests/test_background.py @@ -5,18 +5,21 @@ import pytest from starlette.background import BackgroundTask, BackgroundTasks from starlette.responses import Response from starlette.testclient import TestClient +from starlette.types import Receive, Scope, Send +TestClientFactory = Callable[..., TestClient] -def test_async_task(test_client_factory): + +def test_async_task(test_client_factory: TestClientFactory) -> None: TASK_COMPLETE = False - async def async_task(): + async def async_task() -> None: nonlocal TASK_COMPLETE TASK_COMPLETE = True task = BackgroundTask(async_task) - async def app(scope, receive, send): + async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response("task initiated", media_type="text/plain", background=task) await response(scope, receive, send) @@ -26,16 +29,16 @@ def test_async_task(test_client_factory): assert TASK_COMPLETE -def test_sync_task(test_client_factory): +def test_sync_task(test_client_factory: TestClientFactory) -> None: TASK_COMPLETE = False - def sync_task(): + def sync_task() -> None: nonlocal TASK_COMPLETE TASK_COMPLETE = True task = BackgroundTask(sync_task) - async def app(scope, receive, send): + async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response("task initiated", media_type="text/plain", background=task) await response(scope, receive, send) @@ -45,14 +48,14 @@ def test_sync_task(test_client_factory): assert TASK_COMPLETE -def test_multiple_tasks(test_client_factory: Callable[..., TestClient]): +def test_multiple_tasks(test_client_factory: TestClientFactory) -> None: TASK_COUNTER = 0 - def increment(amount): + def increment(amount: int) -> None: nonlocal TASK_COUNTER TASK_COUNTER += amount - async def app(scope, receive, send): + async def app(scope: Scope, receive: Receive, send: Send) -> None: tasks = BackgroundTasks() tasks.add_task(increment, amount=1) tasks.add_task(increment, amount=2) @@ -69,17 +72,17 @@ def test_multiple_tasks(test_client_factory: Callable[..., TestClient]): def test_multi_tasks_failure_avoids_next_execution( - test_client_factory: Callable[..., TestClient], + test_client_factory: TestClientFactory, ) -> None: TASK_COUNTER = 0 - def increment(): + def increment() -> None: nonlocal TASK_COUNTER TASK_COUNTER += 1 if TASK_COUNTER == 1: raise Exception("task failed") - async def app(scope, receive, send): + async def app(scope: Scope, receive: Receive, send: Send) -> None: tasks = BackgroundTasks() tasks.add_task(increment) tasks.add_task(increment) -- 2.47.2