+from __future__ import annotations
+
import contextlib
import inspect
import io
import json
import math
import queue
+import sys
import typing
import warnings
from concurrent.futures import Future
from urllib.parse import unquote, urljoin
import anyio
+import anyio.abc
import anyio.from_thread
from anyio.abc import ObjectReceiveStream, ObjectSendStream
from anyio.streams.stapled import StapledObjectStream
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocketDisconnect
+if sys.version_info >= (3, 10): # pragma: no cover
+ from typing import TypeGuard
+else: # pragma: no cover
+ from typing_extensions import TypeGuard
+
try:
import httpx
except ModuleNotFoundError: # pragma: no cover
_RequestData = typing.Mapping[str, typing.Union[str, typing.Iterable[str]]]
-def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> bool:
+def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> TypeGuard[ASGI3App]:
if inspect.isclass(app):
return hasattr(app, "__await__")
return is_async_callable(app)
class _Upgrade(Exception):
- def __init__(self, session: "WebSocketTestSession") -> None:
+ def __init__(self, session: WebSocketTestSession) -> None:
self.session = session
self.scope = scope
self.accepted_subprotocol = None
self.portal_factory = portal_factory
- self._receive_queue: "queue.Queue[Message]" = queue.Queue()
- self._send_queue: "queue.Queue[Message | BaseException]" = queue.Queue()
+ self._receive_queue: queue.Queue[Message] = queue.Queue()
+ self._send_queue: queue.Queue[Message | BaseException] = queue.Queue()
self.extra_headers = None
- def __enter__(self) -> "WebSocketTestSession":
+ def __enter__(self) -> WebSocketTestSession:
self.exit_stack = contextlib.ExitStack()
self.portal = self.exit_stack.enter_context(self.portal_factory())
+ self.should_close = anyio.Event()
try:
- _: "Future[None]" = self.portal.start_task_soon(self._run)
+ _: Future[None] = self.portal.start_task_soon(self._run)
self.send({"type": "websocket.connect"})
message = self.receive()
self._raise_on_close(message)
self.extra_headers = message.get("headers", None)
return self
+ async def _notify_close(self) -> None:
+ self.should_close.set()
+
def __exit__(self, *args: typing.Any) -> None:
try:
self.close(1000)
finally:
+ self.portal.start_task_soon(self._notify_close)
self.exit_stack.close()
while not self._send_queue.empty():
message = self._send_queue.get()
"""
The sub-thread in which the websocket session runs.
"""
- scope = self.scope
- receive = self._asgi_receive
- send = self._asgi_send
- try:
- await self.app(scope, receive, send)
- except BaseException as exc:
- self._send_queue.put(exc)
- raise
+
+ async def run_app(tg: anyio.abc.TaskGroup) -> None:
+ try:
+ await self.app(self.scope, self._asgi_receive, self._asgi_send)
+ except anyio.get_cancelled_exc_class():
+ ...
+ except BaseException as exc:
+ self._send_queue.put(exc)
+ raise
+ finally:
+ tg.cancel_scope.cancel()
+
+ async with anyio.create_task_group() as tg:
+ tg.start_soon(run_app, tg)
+ await self.should_close.wait()
+ tg.cancel_scope.cancel()
async def _asgi_receive(self) -> Message:
while self._receive_queue.empty():
else:
self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")})
- def close(self, code: int = 1000, reason: typing.Union[str, None] = None) -> None:
+ def close(self, code: int = 1000, reason: str | None = None) -> None:
self.send({"type": "websocket.disconnect", "code": code, "reason": reason})
def receive(self) -> Message:
self._raise_on_close(message)
return typing.cast(bytes, message["bytes"])
- def receive_json(self, mode: str = "text") -> typing.Any:
- assert mode in ["text", "binary"]
+ def receive_json(
+ self, mode: typing.Literal["text", "binary"] = "text"
+ ) -> typing.Any:
message = self.receive()
self._raise_on_close(message)
if mode == "text":
raise_server_exceptions: bool = True,
root_path: str = "",
*,
- app_state: typing.Dict[str, typing.Any],
+ app_state: dict[str, typing.Any],
) -> None:
self.app = app
self.raise_server_exceptions = raise_server_exceptions
# Include the 'host' header.
if "host" in request.headers:
- headers: typing.List[typing.Tuple[bytes, bytes]] = []
+ headers: list[tuple[bytes, bytes]] = []
elif port == default_port: # pragma: no cover
headers = [(b"host", host.encode())]
else: # pragma: no cover
for key, value in request.headers.multi_items()
]
- scope: typing.Dict[str, typing.Any]
+ scope: dict[str, typing.Any]
if scheme in {"ws", "wss"}:
subprotocol = request.headers.get("sec-websocket-protocol", None)
request_complete = False
response_started = False
response_complete: anyio.Event
- raw_kwargs: typing.Dict[str, typing.Any] = {"stream": io.BytesIO()}
+ raw_kwargs: dict[str, typing.Any] = {"stream": io.BytesIO()}
template = None
context = None
class TestClient(httpx.Client):
__test__ = False
- task: "Future[None]"
- portal: typing.Optional[anyio.abc.BlockingPortal] = None
+ task: Future[None]
+ portal: anyio.abc.BlockingPortal | None = None
def __init__(
self,
base_url: str = "http://testserver",
raise_server_exceptions: bool = True,
root_path: str = "",
- backend: str = "asyncio",
- backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None,
- cookies: httpx._types.CookieTypes = None,
- headers: typing.Dict[str, str] = None,
+ backend: typing.Literal["asyncio", "trio"] = "asyncio",
+ backend_options: typing.Dict[str, typing.Any] | None = None,
+ cookies: httpx._types.CookieTypes | None = None,
+ headers: typing.Dict[str, str] | None = None,
follow_redirects: bool = True,
) -> None:
self.async_backend = _AsyncBackend(
backend=backend, backend_options=backend_options or {}
)
if _is_asgi3(app):
- app = typing.cast(ASGI3App, app)
asgi_app = app
else:
app = typing.cast(ASGI2App, app) # type: ignore[assignment]
yield portal
def _choose_redirect_arg(
- self,
- follow_redirects: typing.Optional[bool],
- allow_redirects: typing.Optional[bool],
- ) -> typing.Union[bool, httpx._client.UseClientDefault]:
- redirect: typing.Union[
- bool, httpx._client.UseClientDefault
- ] = httpx._client.USE_CLIENT_DEFAULT
+ self, follow_redirects: bool | None, allow_redirects: bool | None
+ ) -> bool | httpx._client.UseClientDefault:
+ redirect: bool | httpx._client.UseClientDefault = (
+ httpx._client.USE_CLIENT_DEFAULT
+ )
if allow_redirects is not None:
message = (
"The `allow_redirects` argument is deprecated. "
)
def websocket_connect(
- self, url: str, subprotocols: typing.Sequence[str] = None, **kwargs: typing.Any
+ self,
+ url: str,
+ subprotocols: typing.Sequence[str] | None = None,
+ **kwargs: typing.Any,
) -> "WebSocketTestSession":
url = urljoin("ws://testserver", url)
headers = kwargs.get("headers", {})
from typing import Callable
import anyio
+import anyio.lowlevel
import pytest
import sniffio
import trio.lowlevel
from starlette.responses import JSONResponse, RedirectResponse, Response
from starlette.routing import Route
from starlette.testclient import TestClient
+from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketDisconnect
-def mock_service_endpoint(request):
+def mock_service_endpoint(request: Request):
return JSONResponse({"mock": "example"})
-mock_service = Starlette(
- routes=[
- Route("/", endpoint=mock_service_endpoint),
- ]
-)
+mock_service = Starlette(routes=[Route("/", endpoint=mock_service_endpoint)])
def current_task():
raise RuntimeError()
-def test_use_testclient_in_endpoint(test_client_factory):
+def test_use_testclient_in_endpoint(test_client_factory: Callable[..., TestClient]):
"""
We should be able to use the test client within applications.
during tests or in development.
"""
- def homepage(request):
+ def homepage(request: Request):
client = test_client_factory(mock_service)
response = client.get("/")
return JSONResponse(response.json())
assert client.headers.get("Authentication") == "Bearer 123"
-def test_use_testclient_as_contextmanager(test_client_factory, anyio_backend_name):
+def test_use_testclient_as_contextmanager(
+ test_client_factory: Callable[..., TestClient], anyio_backend_name: str
+):
"""
This test asserts a number of properties that are important for an
app level task_group
shutdown_loop = None
@asynccontextmanager
- async def lifespan_context(app):
+ async def lifespan_context(app: Starlette):
nonlocal startup_task, startup_loop, shutdown_task, shutdown_loop
startup_task = current_task()
startup_loop = get_identity()
- async with anyio.create_task_group() as app.task_group:
+ async with anyio.create_task_group():
yield
shutdown_task = current_task()
shutdown_loop = get_identity()
- async def loop_id(request):
+ async def loop_id(request: Request):
return JSONResponse(get_identity())
app = Starlette(
assert first_task is not startup_task
-def test_error_on_startup(test_client_factory):
+def test_error_on_startup(test_client_factory: Callable[..., TestClient]):
with pytest.deprecated_call(
match="The on_startup and on_shutdown parameters are deprecated"
):
pass # pragma: no cover
-def test_exception_in_middleware(test_client_factory):
+def test_exception_in_middleware(test_client_factory: Callable[..., TestClient]):
class MiddlewareException(Exception):
pass
class BrokenMiddleware:
- def __init__(self, app):
+ def __init__(self, app: ASGIApp):
self.app = app
- async def __call__(self, scope, receive, send):
+ async def __call__(self, scope: Scope, receive: Receive, send: Send):
raise MiddlewareException()
broken_middleware = Starlette(middleware=[Middleware(BrokenMiddleware)])
pass # pragma: no cover
-def test_testclient_asgi2(test_client_factory):
- def app(scope):
- async def inner(receive, send):
+def test_testclient_asgi2(test_client_factory: Callable[..., TestClient]):
+ def app(scope: Scope):
+ async def inner(receive: Receive, send: Send):
await send(
{
"type": "http.response.start",
assert response.text == "Hello, world!"
-def test_testclient_asgi3(test_client_factory):
- async def app(scope, receive, send):
+def test_testclient_asgi3(test_client_factory: Callable[..., TestClient]):
+ async def app(scope: Scope, receive: Receive, send: Send):
await send(
{
"type": "http.response.start",
assert response.text == "Hello, world!"
-def test_websocket_blocking_receive(test_client_factory):
- def app(scope):
- async def respond(websocket):
+def test_websocket_blocking_receive(test_client_factory: Callable[..., TestClient]):
+ def app(scope: Scope):
+ async def respond(websocket: WebSocket):
await websocket.send_json({"message": "test"})
- async def asgi(receive, send):
+ async def asgi(receive: Receive, send: Send):
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
async with anyio.create_task_group() as task_group:
assert data == {"message": "test"}
+def test_websocket_not_block_on_close(test_client_factory: Callable[..., TestClient]):
+ def app(scope: Scope):
+ async def asgi(receive: Receive, send: Send):
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.accept()
+ while True:
+ await anyio.sleep(0.1)
+
+ return asgi
+
+ client = test_client_factory(app)
+ with client.websocket_connect("/") as websocket:
+ ...
+ assert websocket.should_close.is_set()
+
+
@pytest.mark.parametrize("param", ("2020-07-14T00:00:00+00:00", "España", "voilà"))
-def test_query_params(test_client_factory, param: str):
- def homepage(request):
+def test_query_params(test_client_factory: Callable[..., TestClient], param: str):
+ def homepage(request: Request):
return Response(request.query_params["param"])
app = Starlette(routes=[Route("/", endpoint=homepage)])
("example.com", False),
],
)
-def test_domain_restricted_cookies(test_client_factory, domain, ok):
+def test_domain_restricted_cookies(
+ test_client_factory: Callable[..., TestClient], domain: str, ok: bool
+):
"""
Test that test client discards domain restricted cookies which do not match the
base_url of the testclient (`http://testserver` by default).
in accordance with RFC 2965.
"""
- async def app(scope, receive, send):
+ async def app(scope: Scope, receive: Receive, send: Send):
response = Response("Hello, world!", media_type="text/plain")
response.set_cookie(
"mycookie",
assert cookie_set == ok
-def test_forward_follow_redirects(test_client_factory):
- async def app(scope, receive, send):
+def test_forward_follow_redirects(test_client_factory: Callable[..., TestClient]):
+ async def app(scope: Scope, receive: Receive, send: Send):
if "/ok" in scope["path"]:
response = Response("ok")
else:
assert response.status_code == 200
-def test_forward_nofollow_redirects(test_client_factory):
- async def app(scope, receive, send):
+def test_forward_nofollow_redirects(test_client_factory: Callable[..., TestClient]):
+ async def app(scope: Scope, receive: Receive, send: Send):
response = RedirectResponse("/ok")
await response(scope, receive, send)