+from __future__ import annotations
+
import itertools
import sys
-from asyncio import current_task as asyncio_current_task
+from asyncio import Task, current_task as asyncio_current_task
from contextlib import asynccontextmanager
-from typing import Callable
+from typing import Any, AsyncGenerator, Callable
import anyio
import anyio.lowlevel
from starlette.requests import Request
from starlette.responses import JSONResponse, RedirectResponse, Response
from starlette.routing import Route
-from starlette.testclient import TestClient
+from starlette.testclient import ASGIInstance, TestClient
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketDisconnect
+TestClientFactory = Callable[..., TestClient]
+
-def mock_service_endpoint(request: Request):
+def mock_service_endpoint(request: Request) -> JSONResponse:
return JSONResponse({"mock": "example"})
mock_service = Starlette(routes=[Route("/", endpoint=mock_service_endpoint)])
-def current_task():
+def current_task() -> Task[Any] | trio.lowlevel.Task:
# anyio's TaskInfo comparisons are invalid after their associated native
# task object is GC'd https://github.com/agronholm/anyio/issues/324
asynclib_name = sniffio.current_async_library()
raise RuntimeError(f"unsupported asynclib={asynclib_name}") # pragma: no cover
-def startup():
+def startup() -> None:
raise RuntimeError()
-def test_use_testclient_in_endpoint(test_client_factory: Callable[..., TestClient]):
+def test_use_testclient_in_endpoint(test_client_factory: TestClientFactory) -> None:
"""
We should be able to use the test client within applications.
during tests or in development.
"""
- def homepage(request: Request):
+ def homepage(request: Request) -> JSONResponse:
client = test_client_factory(mock_service)
response = client.get("/")
return JSONResponse(response.json())
assert response.json() == {"mock": "example"}
-def test_testclient_headers_behavior():
+def test_testclient_headers_behavior() -> None:
"""
We should be able to use the test client with user defined headers.
def test_use_testclient_as_contextmanager(
- test_client_factory: Callable[..., TestClient], anyio_backend_name: str
-):
+ test_client_factory: TestClientFactory, anyio_backend_name: str
+) -> None:
"""
This test asserts a number of properties that are important for an
app level task_group
counter = itertools.count()
identity_runvar = anyio.lowlevel.RunVar[int]("identity_runvar")
- def get_identity():
+ def get_identity() -> int:
try:
return identity_runvar.get()
except LookupError:
shutdown_loop = None
@asynccontextmanager
- async def lifespan_context(app: Starlette):
+ async def lifespan_context(app: Starlette) -> AsyncGenerator[None, None]:
nonlocal startup_task, startup_loop, shutdown_task, shutdown_loop
startup_task = current_task()
shutdown_task = current_task()
shutdown_loop = get_identity()
- async def loop_id(request: Request):
+ async def loop_id(request: Request) -> JSONResponse:
return JSONResponse(get_identity())
app = Starlette(
assert startup_task is shutdown_task
# outside the TestClient context, new requests continue to spawn in new
- # eventloops in new threads
+ # event loops in new threads
assert client.get("/loop_id").json() == 1
assert client.get("/loop_id").json() == 2
assert first_task is not startup_task
-def test_error_on_startup(test_client_factory: Callable[..., TestClient]):
+def test_error_on_startup(test_client_factory: TestClientFactory) -> None:
with pytest.deprecated_call(
match="The on_startup and on_shutdown parameters are deprecated"
):
pass # pragma: no cover
-def test_exception_in_middleware(test_client_factory: Callable[..., TestClient]):
+def test_exception_in_middleware(test_client_factory: TestClientFactory) -> None:
class MiddlewareException(Exception):
pass
def __init__(self, app: ASGIApp):
self.app = app
- async def __call__(self, scope: Scope, receive: Receive, send: Send):
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
raise MiddlewareException()
broken_middleware = Starlette(middleware=[Middleware(BrokenMiddleware)])
pass # pragma: no cover
-def test_testclient_asgi2(test_client_factory: Callable[..., TestClient]):
- def app(scope: Scope):
- async def inner(receive: Receive, send: Send):
+def test_testclient_asgi2(test_client_factory: TestClientFactory) -> None:
+ def app(scope: Scope) -> ASGIInstance:
+ async def inner(receive: Receive, send: Send) -> None:
await send(
{
"type": "http.response.start",
assert response.text == "Hello, world!"
-def test_testclient_asgi3(test_client_factory: Callable[..., TestClient]):
- async def app(scope: Scope, receive: Receive, send: Send):
+def test_testclient_asgi3(test_client_factory: TestClientFactory) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
await send(
{
"type": "http.response.start",
assert response.text == "Hello, world!"
-def test_websocket_blocking_receive(test_client_factory: Callable[..., TestClient]):
- def app(scope: Scope):
- async def respond(websocket: WebSocket):
+def test_websocket_blocking_receive(test_client_factory: TestClientFactory) -> None:
+ def app(scope: Scope) -> ASGIInstance:
+ async def respond(websocket: WebSocket) -> None:
await websocket.send_json({"message": "test"})
- async def asgi(receive: Receive, send: Send):
+ async def asgi(receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
async with anyio.create_task_group() as task_group:
assert data == {"message": "test"}
-def test_websocket_not_block_on_close(test_client_factory: Callable[..., TestClient]):
- def app(scope: Scope):
- async def asgi(receive: Receive, send: Send):
+def test_websocket_not_block_on_close(test_client_factory: TestClientFactory) -> None:
+ def app(scope: Scope) -> ASGIInstance:
+ async def asgi(receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
while True:
@pytest.mark.parametrize("param", ("2020-07-14T00:00:00+00:00", "España", "voilà"))
-def test_query_params(test_client_factory: Callable[..., TestClient], param: str):
- def homepage(request: Request):
+def test_query_params(test_client_factory: TestClientFactory, param: str) -> None:
+ def homepage(request: Request) -> Response:
return Response(request.query_params["param"])
app = Starlette(routes=[Route("/", endpoint=homepage)])
],
)
def test_domain_restricted_cookies(
- test_client_factory: Callable[..., TestClient], domain: str, ok: bool
-):
+ test_client_factory: TestClientFactory, domain: str, ok: bool
+) -> None:
"""
Test that test client discards domain restricted cookies which do not match the
base_url of the testclient (`http://testserver` by default).
in accordance with RFC 2965.
"""
- async def app(scope: Scope, receive: Receive, send: Send):
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = Response("Hello, world!", media_type="text/plain")
response.set_cookie(
"mycookie",
assert cookie_set == ok
-def test_forward_follow_redirects(test_client_factory: Callable[..., TestClient]):
- async def app(scope: Scope, receive: Receive, send: Send):
+def test_forward_follow_redirects(test_client_factory: TestClientFactory) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
if "/ok" in scope["path"]:
response = Response("ok")
else:
assert response.status_code == 200
-def test_forward_nofollow_redirects(test_client_factory: Callable[..., TestClient]):
- async def app(scope: Scope, receive: Receive, send: Send):
+def test_forward_nofollow_redirects(test_client_factory: TestClientFactory) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = RedirectResponse("/ok")
await response(scope, receive, send)
assert response.status_code == 307
-def test_with_duplicate_headers(test_client_factory: Callable[[Starlette], TestClient]):
+def test_with_duplicate_headers(test_client_factory: TestClientFactory) -> None:
def homepage(request: Request) -> JSONResponse:
return JSONResponse({"x-token": request.headers.getlist("x-token")})
assert response.json() == {"x-token": ["foo", "bar"]}
-def test_merge_url(test_client_factory: Callable[..., TestClient]):
+def test_merge_url(test_client_factory: TestClientFactory) -> None:
def homepage(request: Request) -> Response:
return Response(request.url.path)