import asyncio
+import contextlib
import functools
import inspect
import re
+import sys
import traceback
+import types
import typing
+import warnings
from enum import Enum
from starlette.concurrency import run_in_threadpool
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketClose
+if sys.version_info >= (3, 7):
+ from contextlib import asynccontextmanager # pragma: no cover
+else:
+ from contextlib2 import asynccontextmanager # pragma: no cover
+
class NoMatchFound(Exception):
"""
)
+_T = typing.TypeVar("_T")
+
+
+class _AsyncLiftContextManager(typing.AsyncContextManager[_T]):
+ def __init__(self, cm: typing.ContextManager[_T]):
+ self._cm = cm
+
+ async def __aenter__(self) -> _T:
+ return self._cm.__enter__()
+
+ async def __aexit__(
+ self,
+ exc_type: typing.Optional[typing.Type[BaseException]],
+ exc_value: typing.Optional[BaseException],
+ traceback: typing.Optional[types.TracebackType],
+ ) -> typing.Optional[bool]:
+ return self._cm.__exit__(exc_type, exc_value, traceback)
+
+
+def _wrap_gen_lifespan_context(
+ lifespan_context: typing.Callable[[typing.Any], typing.Generator]
+) -> typing.Callable[[typing.Any], typing.AsyncContextManager]:
+ cmgr = contextlib.contextmanager(lifespan_context)
+
+ @functools.wraps(cmgr)
+ def wrapper(app: typing.Any) -> _AsyncLiftContextManager:
+ return _AsyncLiftContextManager(cmgr(app))
+
+ return wrapper
+
+
+class _DefaultLifespan:
+ def __init__(self, router: "Router"):
+ self._router = router
+
+ async def __aenter__(self) -> None:
+ await self._router.startup()
+
+ async def __aexit__(self, *exc_info: object) -> None:
+ await self._router.shutdown()
+
+ def __call__(self: _T, app: object) -> _T:
+ return self
+
+
class Router:
def __init__(
self,
default: ASGIApp = None,
on_startup: typing.Sequence[typing.Callable] = None,
on_shutdown: typing.Sequence[typing.Callable] = None,
- lifespan: typing.Callable[[typing.Any], typing.AsyncGenerator] = None,
+ lifespan: typing.Callable[[typing.Any], typing.AsyncContextManager] = None,
) -> None:
self.routes = [] if routes is None else list(routes)
self.redirect_slashes = redirect_slashes
self.on_startup = [] if on_startup is None else list(on_startup)
self.on_shutdown = [] if on_shutdown is None else list(on_shutdown)
- async def default_lifespan(app: typing.Any) -> typing.AsyncGenerator:
- await self.startup()
- yield
- await self.shutdown()
+ if lifespan is None:
+ self.lifespan_context: typing.Callable[
+ [typing.Any], typing.AsyncContextManager
+ ] = _DefaultLifespan(self)
- self.lifespan_context = default_lifespan if lifespan is None else lifespan
+ elif inspect.isasyncgenfunction(lifespan):
+ warnings.warn(
+ "async generator function lifespans are deprecated, "
+ "use an @contextlib.asynccontextmanager function instead",
+ DeprecationWarning,
+ )
+ self.lifespan_context = asynccontextmanager(
+ lifespan, # type: ignore[arg-type]
+ )
+ elif inspect.isgeneratorfunction(lifespan):
+ warnings.warn(
+ "generator function lifespans are deprecated, "
+ "use an @contextlib.asynccontextmanager function instead",
+ DeprecationWarning,
+ )
+ self.lifespan_context = _wrap_gen_lifespan_context(
+ lifespan, # type: ignore[arg-type]
+ )
+ else:
+ self.lifespan_context = lifespan
async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "websocket":
Handle ASGI lifespan messages, which allows us to manage application
startup and shutdown events.
"""
- first = True
+ started = False
app = scope.get("app")
await receive()
try:
- if inspect.isasyncgenfunction(self.lifespan_context):
- async for item in self.lifespan_context(app):
- assert first, "Lifespan context yielded multiple times."
- first = False
- await send({"type": "lifespan.startup.complete"})
- await receive()
- else:
- for item in self.lifespan_context(app): # type: ignore
- assert first, "Lifespan context yielded multiple times."
- first = False
- await send({"type": "lifespan.startup.complete"})
- await receive()
+ async with self.lifespan_context(app):
+ await send({"type": "lifespan.startup.complete"})
+ started = True
+ await receive()
except BaseException:
- if first:
- exc_text = traceback.format_exc()
+ exc_text = traceback.format_exc()
+ if started:
+ await send({"type": "lifespan.shutdown.failed", "message": exc_text})
+ else:
await send({"type": "lifespan.startup.failed", "message": exc_text})
raise
else:
async def wait_startup(self) -> None:
await self.stream_receive.send({"type": "lifespan.startup"})
- message = await self.stream_send.receive()
- if message is None:
- self.task.result()
+
+ async def receive() -> typing.Any:
+ message = await self.stream_send.receive()
+ if message is None:
+ self.task.result()
+ return message
+
+ message = await receive()
assert message["type"] in (
"lifespan.startup.complete",
"lifespan.startup.failed",
)
if message["type"] == "lifespan.startup.failed":
+ await receive()
+
+ async def wait_shutdown(self) -> None:
+ async def receive() -> typing.Any:
message = await self.stream_send.receive()
if message is None:
self.task.result()
+ return message
- async def wait_shutdown(self) -> None:
async with self.stream_send:
await self.stream_receive.send({"type": "lifespan.shutdown"})
- message = await self.stream_send.receive()
- if message is None:
- self.task.result()
- assert message["type"] == "lifespan.shutdown.complete"
+ message = await receive()
+ assert message["type"] in (
+ "lifespan.shutdown.complete",
+ "lifespan.shutdown.failed",
+ )
+ if message["type"] == "lifespan.shutdown.failed":
+ await receive()
import os
+import sys
import pytest
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
from starlette.staticfiles import StaticFiles
+if sys.version_info >= (3, 7):
+ from contextlib import asynccontextmanager # pragma: no cover
+else:
+ from contextlib2 import asynccontextmanager # pragma: no cover
+
app = Starlette()
assert cleanup_complete
-def test_app_async_lifespan(test_client_factory):
+def test_app_async_cm_lifespan(test_client_factory):
+ startup_complete = False
+ cleanup_complete = False
+
+ @asynccontextmanager
+ async def lifespan(app):
+ nonlocal startup_complete, cleanup_complete
+ startup_complete = True
+ yield
+ cleanup_complete = True
+
+ app = Starlette(lifespan=lifespan)
+
+ assert not startup_complete
+ assert not cleanup_complete
+ with test_client_factory(app):
+ assert startup_complete
+ assert not cleanup_complete
+ assert startup_complete
+ assert cleanup_complete
+
+
+deprecated_lifespan = pytest.mark.filterwarnings(
+ r"ignore"
+ r":(async )?generator function lifespans are deprecated, use an "
+ r"@contextlib\.asynccontextmanager function instead"
+ r":DeprecationWarning"
+ r":starlette.routing"
+)
+
+
+@deprecated_lifespan
+def test_app_async_gen_lifespan(test_client_factory):
startup_complete = False
cleanup_complete = False
assert cleanup_complete
-def test_app_sync_lifespan(test_client_factory):
+@deprecated_lifespan
+def test_app_sync_gen_lifespan(test_client_factory):
startup_complete = False
cleanup_complete = False
from starlette.responses import JSONResponse
from starlette.websockets import WebSocket, WebSocketDisconnect
-if sys.version_info >= (3, 7):
- from asyncio import current_task as asyncio_current_task # pragma: no cover
-else:
- asyncio_current_task = asyncio.Task.current_task # pragma: no cover
+if sys.version_info >= (3, 7): # pragma: no cover
+ from asyncio import current_task as asyncio_current_task
+ from contextlib import asynccontextmanager
+else: # pragma: no cover
+ asyncio_current_task = asyncio.Task.current_task
+ from contextlib2 import asynccontextmanager
mock_service = Starlette()
shutdown_task = object()
shutdown_loop = None
+ @asynccontextmanager
async def lifespan_context(app):
nonlocal startup_task, startup_loop, shutdown_task, shutdown_loop