]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
use an async context manager factory for lifespan (#1227)
authorThomas Grainger <tagrain@gmail.com>
Sat, 3 Jul 2021 17:43:24 +0000 (18:43 +0100)
committerGitHub <noreply@github.com>
Sat, 3 Jul 2021 17:43:24 +0000 (18:43 +0100)
setup.py
starlette/applications.py
starlette/routing.py
starlette/testclient.py
tests/test_applications.py
tests/test_testclient.py

index ac6479746fdd966a64a3f60d85f4db27879bef25..31789fe09dc8e62d554dc5da124a810e9844e72b 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -40,6 +40,7 @@ setup(
     install_requires=[
         "anyio>=3.0.0,<4",
         "typing_extensions; python_version < '3.8'",
+        "contextlib2 >= 21.6.0; python_version < '3.7'",
     ],
     extras_require={
         "full": [
index 34c3e38bd98fc7b8b729fa9506cfc7702bc5cf9c..ea52ee70eea462a27df46eb9f15652a4037a919e 100644 (file)
@@ -46,7 +46,7 @@ class Starlette:
         ] = None,
         on_startup: typing.Sequence[typing.Callable] = None,
         on_shutdown: typing.Sequence[typing.Callable] = None,
-        lifespan: typing.Callable[["Starlette"], typing.AsyncGenerator] = None,
+        lifespan: typing.Callable[["Starlette"], typing.AsyncContextManager] = None,
     ) -> None:
         # The lifespan context function is a newer style that replaces
         # on_startup / on_shutdown handlers. Use one or the other, not both.
index cef1ef4848fbb4ad95de4e3c96372f8b05eb59f5..9a1a5e12df719775408a3454096e5147ebad4265 100644 (file)
@@ -1,9 +1,13 @@
 import asyncio
+import contextlib
 import functools
 import inspect
 import re
+import sys
 import traceback
+import types
 import typing
+import warnings
 from enum import Enum
 
 from starlette.concurrency import run_in_threadpool
@@ -15,6 +19,11 @@ from starlette.responses import PlainTextResponse, RedirectResponse
 from starlette.types import ASGIApp, Receive, Scope, Send
 from starlette.websockets import WebSocket, WebSocketClose
 
+if sys.version_info >= (3, 7):
+    from contextlib import asynccontextmanager  # pragma: no cover
+else:
+    from contextlib2 import asynccontextmanager  # pragma: no cover
+
 
 class NoMatchFound(Exception):
     """
@@ -470,6 +479,51 @@ class Host(BaseRoute):
         )
 
 
+_T = typing.TypeVar("_T")
+
+
+class _AsyncLiftContextManager(typing.AsyncContextManager[_T]):
+    def __init__(self, cm: typing.ContextManager[_T]):
+        self._cm = cm
+
+    async def __aenter__(self) -> _T:
+        return self._cm.__enter__()
+
+    async def __aexit__(
+        self,
+        exc_type: typing.Optional[typing.Type[BaseException]],
+        exc_value: typing.Optional[BaseException],
+        traceback: typing.Optional[types.TracebackType],
+    ) -> typing.Optional[bool]:
+        return self._cm.__exit__(exc_type, exc_value, traceback)
+
+
+def _wrap_gen_lifespan_context(
+    lifespan_context: typing.Callable[[typing.Any], typing.Generator]
+) -> typing.Callable[[typing.Any], typing.AsyncContextManager]:
+    cmgr = contextlib.contextmanager(lifespan_context)
+
+    @functools.wraps(cmgr)
+    def wrapper(app: typing.Any) -> _AsyncLiftContextManager:
+        return _AsyncLiftContextManager(cmgr(app))
+
+    return wrapper
+
+
+class _DefaultLifespan:
+    def __init__(self, router: "Router"):
+        self._router = router
+
+    async def __aenter__(self) -> None:
+        await self._router.startup()
+
+    async def __aexit__(self, *exc_info: object) -> None:
+        await self._router.shutdown()
+
+    def __call__(self: _T, app: object) -> _T:
+        return self
+
+
 class Router:
     def __init__(
         self,
@@ -478,7 +532,7 @@ class Router:
         default: ASGIApp = None,
         on_startup: typing.Sequence[typing.Callable] = None,
         on_shutdown: typing.Sequence[typing.Callable] = None,
-        lifespan: typing.Callable[[typing.Any], typing.AsyncGenerator] = None,
+        lifespan: typing.Callable[[typing.Any], typing.AsyncContextManager] = None,
     ) -> None:
         self.routes = [] if routes is None else list(routes)
         self.redirect_slashes = redirect_slashes
@@ -486,12 +540,31 @@ class Router:
         self.on_startup = [] if on_startup is None else list(on_startup)
         self.on_shutdown = [] if on_shutdown is None else list(on_shutdown)
 
-        async def default_lifespan(app: typing.Any) -> typing.AsyncGenerator:
-            await self.startup()
-            yield
-            await self.shutdown()
+        if lifespan is None:
+            self.lifespan_context: typing.Callable[
+                [typing.Any], typing.AsyncContextManager
+            ] = _DefaultLifespan(self)
 
-        self.lifespan_context = default_lifespan if lifespan is None else lifespan
+        elif inspect.isasyncgenfunction(lifespan):
+            warnings.warn(
+                "async generator function lifespans are deprecated, "
+                "use an @contextlib.asynccontextmanager function instead",
+                DeprecationWarning,
+            )
+            self.lifespan_context = asynccontextmanager(
+                lifespan,  # type: ignore[arg-type]
+            )
+        elif inspect.isgeneratorfunction(lifespan):
+            warnings.warn(
+                "generator function lifespans are deprecated, "
+                "use an @contextlib.asynccontextmanager function instead",
+                DeprecationWarning,
+            )
+            self.lifespan_context = _wrap_gen_lifespan_context(
+                lifespan,  # type: ignore[arg-type]
+            )
+        else:
+            self.lifespan_context = lifespan
 
     async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
         if scope["type"] == "websocket":
@@ -541,25 +614,19 @@ class Router:
         Handle ASGI lifespan messages, which allows us to manage application
         startup and shutdown events.
         """
-        first = True
+        started = False
         app = scope.get("app")
         await receive()
         try:
-            if inspect.isasyncgenfunction(self.lifespan_context):
-                async for item in self.lifespan_context(app):
-                    assert first, "Lifespan context yielded multiple times."
-                    first = False
-                    await send({"type": "lifespan.startup.complete"})
-                    await receive()
-            else:
-                for item in self.lifespan_context(app):  # type: ignore
-                    assert first, "Lifespan context yielded multiple times."
-                    first = False
-                    await send({"type": "lifespan.startup.complete"})
-                    await receive()
+            async with self.lifespan_context(app):
+                await send({"type": "lifespan.startup.complete"})
+                started = True
+                await receive()
         except BaseException:
-            if first:
-                exc_text = traceback.format_exc()
+            exc_text = traceback.format_exc()
+            if started:
+                await send({"type": "lifespan.shutdown.failed", "message": exc_text})
+            else:
                 await send({"type": "lifespan.startup.failed", "message": exc_text})
             raise
         else:
index 7aa59fb9e672629c3d53a7e8726852e0b2e8604d..08d03fa5c4297fc1d5322c9a3ff951e9de747abd 100644 (file)
@@ -543,22 +543,34 @@ class TestClient(requests.Session):
 
     async def wait_startup(self) -> None:
         await self.stream_receive.send({"type": "lifespan.startup"})
-        message = await self.stream_send.receive()
-        if message is None:
-            self.task.result()
+
+        async def receive() -> typing.Any:
+            message = await self.stream_send.receive()
+            if message is None:
+                self.task.result()
+            return message
+
+        message = await receive()
         assert message["type"] in (
             "lifespan.startup.complete",
             "lifespan.startup.failed",
         )
         if message["type"] == "lifespan.startup.failed":
+            await receive()
+
+    async def wait_shutdown(self) -> None:
+        async def receive() -> typing.Any:
             message = await self.stream_send.receive()
             if message is None:
                 self.task.result()
+            return message
 
-    async def wait_shutdown(self) -> None:
         async with self.stream_send:
             await self.stream_receive.send({"type": "lifespan.shutdown"})
-            message = await self.stream_send.receive()
-            if message is None:
-                self.task.result()
-            assert message["type"] == "lifespan.shutdown.complete"
+            message = await receive()
+            assert message["type"] in (
+                "lifespan.shutdown.complete",
+                "lifespan.shutdown.failed",
+            )
+            if message["type"] == "lifespan.shutdown.failed":
+                await receive()
index 6cb490696ddfe0f562ac03e390caf436c12c025a..f5f4c7fbea80dcb97f9d202c8a28f2baa60d79a8 100644 (file)
@@ -1,4 +1,5 @@
 import os
+import sys
 
 import pytest
 
@@ -10,6 +11,11 @@ from starlette.responses import JSONResponse, PlainTextResponse
 from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
 from starlette.staticfiles import StaticFiles
 
+if sys.version_info >= (3, 7):
+    from contextlib import asynccontextmanager  # pragma: no cover
+else:
+    from contextlib2 import asynccontextmanager  # pragma: no cover
+
 app = Starlette()
 
 
@@ -286,7 +292,39 @@ def test_app_add_event_handler(test_client_factory):
     assert cleanup_complete
 
 
-def test_app_async_lifespan(test_client_factory):
+def test_app_async_cm_lifespan(test_client_factory):
+    startup_complete = False
+    cleanup_complete = False
+
+    @asynccontextmanager
+    async def lifespan(app):
+        nonlocal startup_complete, cleanup_complete
+        startup_complete = True
+        yield
+        cleanup_complete = True
+
+    app = Starlette(lifespan=lifespan)
+
+    assert not startup_complete
+    assert not cleanup_complete
+    with test_client_factory(app):
+        assert startup_complete
+        assert not cleanup_complete
+    assert startup_complete
+    assert cleanup_complete
+
+
+deprecated_lifespan = pytest.mark.filterwarnings(
+    r"ignore"
+    r":(async )?generator function lifespans are deprecated, use an "
+    r"@contextlib\.asynccontextmanager function instead"
+    r":DeprecationWarning"
+    r":starlette.routing"
+)
+
+
+@deprecated_lifespan
+def test_app_async_gen_lifespan(test_client_factory):
     startup_complete = False
     cleanup_complete = False
 
@@ -307,7 +345,8 @@ def test_app_async_lifespan(test_client_factory):
     assert cleanup_complete
 
 
-def test_app_sync_lifespan(test_client_factory):
+@deprecated_lifespan
+def test_app_sync_gen_lifespan(test_client_factory):
     startup_complete = False
     cleanup_complete = False
 
index 57ea1c3dbfd8ec2f9c3ca9d7e2485a01176440e8..8c066678966c7274aa25839a8086d5e9d3f62a5c 100644 (file)
@@ -12,10 +12,12 @@ from starlette.middleware import Middleware
 from starlette.responses import JSONResponse
 from starlette.websockets import WebSocket, WebSocketDisconnect
 
-if sys.version_info >= (3, 7):
-    from asyncio import current_task as asyncio_current_task  # pragma: no cover
-else:
-    asyncio_current_task = asyncio.Task.current_task  # pragma: no cover
+if sys.version_info >= (3, 7):  # pragma: no cover
+    from asyncio import current_task as asyncio_current_task
+    from contextlib import asynccontextmanager
+else:  # pragma: no cover
+    asyncio_current_task = asyncio.Task.current_task
+    from contextlib2 import asynccontextmanager
 
 mock_service = Starlette()
 
@@ -90,6 +92,7 @@ def test_use_testclient_as_contextmanager(test_client_factory, anyio_backend_nam
     shutdown_task = object()
     shutdown_loop = None
 
+    @asynccontextmanager
     async def lifespan_context(app):
         nonlocal startup_task, startup_loop, shutdown_task, shutdown_loop