from concurrent.futures import Future
from urllib.parse import unquote, urljoin, urlsplit
-import anyio
+import anyio.abc
import requests
from anyio.streams.stapled import StapledObjectStream
else: # pragma: no cover
from typing_extensions import TypedDict
+
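+# A portal factory is a zero-argument callable that returns a context manager
+# yielding a blocking portal for running coroutines from this synchronous code.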
+_PortalFactoryType = typing.Callable[
+ [], typing.ContextManager[anyio.abc.BlockingPortal]
+]
+
+
# Annotations for `Session.request()`
Cookies = typing.Union[
typing.MutableMapping[str, str], requests.cookies.RequestsCookieJar
def __init__(
self,
app: ASGI3App,
- async_backend: _AsyncBackend,
+ portal_factory: _PortalFactoryType,
raise_server_exceptions: bool = True,
root_path: str = "",
) -> None:
self.app = app
self.raise_server_exceptions = raise_server_exceptions
self.root_path = root_path
- self.async_backend = async_backend
+ self.portal_factory = portal_factory
def send(
self, request: requests.PreparedRequest, *args: typing.Any, **kwargs: typing.Any
"server": [host, port],
"subprotocols": subprotocols,
}
- session = WebSocketTestSession(self.app, scope, self.async_backend)
+ session = WebSocketTestSession(self.app, scope, self.portal_factory)
raise _Upgrade(session)
scope = {
context = message["context"]
try:
- with anyio.start_blocking_portal(**self.async_backend) as portal:
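+            # Drive the ASGI app through a blocking portal; the factory decides
+            # whether to reuse an existing portal or open a fresh one.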
+ with self.portal_factory() as portal:
response_complete = portal.call(anyio.Event)
portal.call(self.app, scope, receive, send)
except BaseException as exc:
self,
app: ASGI3App,
scope: Scope,
- async_backend: _AsyncBackend,
+ portal_factory: _PortalFactoryType,
) -> None:
self.app = app
self.scope = scope
self.accepted_subprotocol = None
- self.async_backend = async_backend
+ self.portal_factory = portal_factory
self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue()
self._send_queue: "queue.Queue[typing.Any]" = queue.Queue()
def __enter__(self) -> "WebSocketTestSession":
self.exit_stack = contextlib.ExitStack()
- self.portal = self.exit_stack.enter_context(
- anyio.start_blocking_portal(**self.async_backend)
- )
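+        # The websocket session keeps its portal open for the whole `with`
+        # block by entering it on the exit stack.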
+ self.portal = self.exit_stack.enter_context(self.portal_factory())
try:
_: "Future[None]" = self.portal.start_task_soon(self._run)
class TestClient(requests.Session):
__test__ = False # For pytest to not discover this up.
task: "Future[None]"
+ portal: typing.Optional[anyio.abc.BlockingPortal] = None
def __init__(
self,
asgi_app = _WrapASGI2(app) # type: ignore
adapter = _ASGIAdapter(
asgi_app,
- self.async_backend,
+ portal_factory=self._portal_factory,
raise_server_exceptions=raise_server_exceptions,
root_path=root_path,
)
self.app = asgi_app
self.base_url = base_url
+ @contextlib.contextmanager
+ def _portal_factory(
+ self,
+ ) -> typing.Generator[anyio.abc.BlockingPortal, None, None]:
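+        # Reuse the portal opened by __enter__() when the client is used as a
+        # context manager; otherwise create a short-lived portal for this call.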
+ if self.portal is not None:
+ yield self.portal
+ else:
+ with anyio.start_blocking_portal(**self.async_backend) as portal:
+ yield portal
+
def request( # type: ignore
self,
method: str,
return session
def __enter__(self) -> "TestClient":
- self.exit_stack = contextlib.ExitStack()
- self.portal = self.exit_stack.enter_context(
- anyio.start_blocking_portal(**self.async_backend)
- )
- self.stream_send = StapledObjectStream(
- *anyio.create_memory_object_stream(math.inf)
- )
- self.stream_receive = StapledObjectStream(
- *anyio.create_memory_object_stream(math.inf)
- )
- try:
- self.task = self.portal.start_task_soon(self.lifespan)
- self.portal.call(self.wait_startup)
- except Exception:
- self.exit_stack.close()
- raise
+ with contextlib.ExitStack() as stack:
+ self.portal = portal = stack.enter_context(
+ anyio.start_blocking_portal(**self.async_backend)
+ )
+
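+            # Clear the shared portal again once the exit stack unwinds.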
+ @stack.callback
+ def reset_portal() -> None:
+ self.portal = None
+
+ self.stream_send = StapledObjectStream(
+ *anyio.create_memory_object_stream(math.inf)
+ )
+ self.stream_receive = StapledObjectStream(
+ *anyio.create_memory_object_stream(math.inf)
+ )
+ self.task = portal.start_task_soon(self.lifespan)
+ portal.call(self.wait_startup)
+
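+            # Defer the lifespan shutdown until the exit stack is closed.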
+ @stack.callback
+ def wait_shutdown() -> None:
+ portal.call(self.wait_shutdown)
+
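+            # Hand the registered callbacks over to a new stack so they run in
+            # __exit__() instead of at the end of this `with` block.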
+ self.exit_stack = stack.pop_all()
+
return self
def __exit__(self, *args: typing.Any) -> None:
- try:
- self.portal.call(self.wait_shutdown)
- finally:
- self.exit_stack.close()
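+        # Closing the stack runs wait_shutdown and reset_portal registered in
+        # __enter__(), then closes the blocking portal.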
+ self.exit_stack.close()
async def lifespan(self) -> None:
scope = {"type": "lifespan"}
+import asyncio
+import itertools
+import sys
+
import anyio
import pytest
+import sniffio
+import trio.lowlevel
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.responses import JSONResponse
from starlette.websockets import WebSocket, WebSocketDisconnect
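+# asyncio.current_task() is only available on Python 3.7+; fall back to the
+# Task.current_task() classmethod on older versions.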
+if sys.version_info >= (3, 7):
+ from asyncio import current_task as asyncio_current_task # pragma: no cover
+else:
+ asyncio_current_task = asyncio.Task.current_task # pragma: no cover
+
mock_service = Starlette()
return JSONResponse({"mock": "example"})
-def create_app(test_client_factory):
- app = Starlette()
-
- @app.route("/")
- def homepage(request):
- client = test_client_factory(mock_service)
- response = client.get("/")
- return JSONResponse(response.json())
+def current_task():
+ # anyio's TaskInfo comparisons are invalid after their associated native
+ # task object is GC'd https://github.com/agronholm/anyio/issues/324
+ asynclib_name = sniffio.current_async_library()
+ if asynclib_name == "trio":
+ return trio.lowlevel.current_task()
- return app
+ if asynclib_name == "asyncio":
+ task = asyncio_current_task()
+ if task is None:
+ raise RuntimeError("must be called from a running task") # pragma: no cover
+ return task
+ raise RuntimeError(f"unsupported asynclib={asynclib_name}") # pragma: no cover
startup_error_app = Starlette()
This is useful if we need to mock out other services,
during tests or in development.
"""
- client = test_client_factory(create_app(test_client_factory))
+
+ app = Starlette()
+
+ @app.route("/")
+ def homepage(request):
+ client = test_client_factory(mock_service)
+ response = client.get("/")
+ return JSONResponse(response.json())
+
+ client = test_client_factory(app)
response = client.get("/")
assert response.json() == {"mock": "example"}
-def test_use_testclient_as_contextmanager(test_client_factory):
- with test_client_factory(create_app(test_client_factory)):
- pass
+def test_use_testclient_as_contextmanager(test_client_factory, anyio_backend_name):
+    """
+    This test asserts a number of properties that are important for an
+    app-level task_group.
+    """
+ counter = itertools.count()
+ identity_runvar = anyio.lowlevel.RunVar[int]("identity_runvar")
+
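+    # Assign a small integer to each event loop via a RunVar so requests can
+    # report which loop (and therefore which thread) served them.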
+ def get_identity():
+ try:
+ return identity_runvar.get()
+ except LookupError:
+ token = next(counter)
+ identity_runvar.set(token)
+ return token
+
+ startup_task = object()
+ startup_loop = None
+ shutdown_task = object()
+ shutdown_loop = None
+
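+    # Record which task and which event loop handle startup and shutdown so the
+    # assertions below can compare them against the request loop ids.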
+ async def lifespan_context(app):
+ nonlocal startup_task, startup_loop, shutdown_task, shutdown_loop
+
+ startup_task = current_task()
+ startup_loop = get_identity()
+ async with anyio.create_task_group() as app.task_group:
+ yield
+ shutdown_task = current_task()
+ shutdown_loop = get_identity()
+
+ app = Starlette(lifespan=lifespan_context)
+
+ @app.route("/loop_id")
+ async def loop_id(request):
+ return JSONResponse(get_identity())
+
+ client = test_client_factory(app)
+
+ with client:
+ # within a TestClient context every async request runs in the same thread
+ assert client.get("/loop_id").json() == 0
+ assert client.get("/loop_id").json() == 0
+
+ # that thread is also the same as the lifespan thread
+ assert startup_loop == 0
+ assert shutdown_loop == 0
+
+        # lifespan events run in the same task; this is important because a
+        # task group must be entered and exited in the same task.
+ assert startup_task is shutdown_task
+
+    # outside the TestClient context, new requests continue to spawn in new
+    # event loops in new threads
+ assert client.get("/loop_id").json() == 1
+ assert client.get("/loop_id").json() == 2
+
+ first_task = startup_task
+
+ with client:
+ # the TestClient context can be re-used, starting a new lifespan task
+ # in a new thread
+ assert client.get("/loop_id").json() == 3
+ assert client.get("/loop_id").json() == 3
+
+ assert startup_loop == 3
+ assert shutdown_loop == 3
+
+        # lifespan events still run in the same task within the context, but...
+ assert startup_task is shutdown_task
+
+ # ... the second TestClient context creates a new lifespan task.
+ assert first_task is not startup_task
def test_error_on_startup(test_client_factory):