import contextvars
from contextlib import AsyncExitStack
-from typing import Any, AsyncGenerator, Awaitable, Callable, List, Type, Union
+from typing import (
+ Any,
+ AsyncGenerator,
+ Callable,
+ Generator,
+ List,
+ Type,
+ Union,
+)
import anyio
import pytest
+from anyio.abc import TaskStatus
from starlette.applications import Starlette
from starlette.background import BackgroundTask
from starlette.routing import Route, WebSocketRoute
from starlette.testclient import TestClient
from starlette.types import ASGIApp, Message, Receive, Scope, Send
+from starlette.websockets import WebSocket
+
+TestClientFactory = Callable[[ASGIApp], TestClient]
class CustomMiddleware(BaseHTTPMiddleware):
- async def dispatch(self, request, call_next):
+ async def dispatch(
+ self,
+ request: Request,
+ call_next: RequestResponseEndpoint,
+ ) -> Response:
response = await call_next(request)
response.headers["Custom-Header"] = "Example"
return response
-def homepage(request):
+def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage")
-def exc(request):
+def exc(request: Request) -> None:
raise Exception("Exc")
-def exc_stream(request):
+def exc_stream(request: Request) -> StreamingResponse:
return StreamingResponse(_generate_faulty_stream())
-def _generate_faulty_stream():
+def _generate_faulty_stream() -> Generator[bytes, None, None]:
yield b"Ok"
raise Exception("Faulty Stream")
class NoResponse:
- def __init__(self, scope, receive, send):
+ def __init__(
+ self,
+ scope: Scope,
+ receive: Receive,
+ send: Send,
+ ):
pass
- def __await__(self):
+ def __await__(self) -> Generator[Any, None, None]:
return self.dispatch().__await__()
- async def dispatch(self):
+ async def dispatch(self) -> None:
pass
-async def websocket_endpoint(session):
+async def websocket_endpoint(session: WebSocket) -> None:
await session.accept()
await session.send_text("Hello, world!")
await session.close()
)
-def test_custom_middleware(test_client_factory):
+def test_custom_middleware(test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app)
response = client.get("/")
assert response.headers["Custom-Header"] == "Example"
assert text == "Hello, world!"
-def test_state_data_across_multiple_middlewares(test_client_factory):
+def test_state_data_across_multiple_middlewares(
+ test_client_factory: TestClientFactory,
+) -> None:
expected_value1 = "foo"
expected_value2 = "bar"
class aMiddleware(BaseHTTPMiddleware):
- async def dispatch(self, request, call_next):
+ async def dispatch(
+ self,
+ request: Request,
+ call_next: RequestResponseEndpoint,
+ ) -> Response:
request.state.foo = expected_value1
response = await call_next(request)
return response
class bMiddleware(BaseHTTPMiddleware):
- async def dispatch(self, request, call_next):
+ async def dispatch(
+ self,
+ request: Request,
+ call_next: RequestResponseEndpoint,
+ ) -> Response:
request.state.bar = expected_value2
response = await call_next(request)
response.headers["X-State-Foo"] = request.state.foo
return response
class cMiddleware(BaseHTTPMiddleware):
- async def dispatch(self, request, call_next):
+ async def dispatch(
+ self,
+ request: Request,
+ call_next: RequestResponseEndpoint,
+ ) -> Response:
response = await call_next(request)
response.headers["X-State-Bar"] = request.state.bar
return response
- def homepage(request):
+ def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("OK")
app = Starlette(
assert response.headers["X-State-Bar"] == expected_value2
-def test_app_middleware_argument(test_client_factory):
- def homepage(request):
+def test_app_middleware_argument(test_client_factory: TestClientFactory) -> None:
+ def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage")
app = Starlette(
assert response.headers["Custom-Header"] == "Example"
-def test_fully_evaluated_response(test_client_factory):
+def test_fully_evaluated_response(test_client_factory: TestClientFactory) -> None:
# Test for https://github.com/encode/starlette/issues/1022
class CustomMiddleware(BaseHTTPMiddleware):
- async def dispatch(self, request, call_next):
+ async def dispatch(
+ self,
+ request: Request,
+ call_next: RequestResponseEndpoint,
+ ) -> PlainTextResponse:
await call_next(request)
return PlainTextResponse("Custom")
class CustomMiddlewareUsingBaseHTTPMiddleware(BaseHTTPMiddleware):
- async def dispatch(self, request, call_next):
+ async def dispatch(
+ self,
+ request: Request,
+ call_next: RequestResponseEndpoint,
+ ) -> Response:
ctxvar.set("set by middleware")
resp = await call_next(request)
assert ctxvar.get() == "set by endpoint"
),
],
)
-def test_contextvars(test_client_factory, middleware_cls: Type[_MiddlewareClass[Any]]):
+def test_contextvars(
+ test_client_factory: TestClientFactory,
+ middleware_cls: Type[_MiddlewareClass[Any]],
+) -> None:
# this has to be an async endpoint because Starlette calls run_in_threadpool
# on sync endpoints which has it's own set of peculiarities w.r.t propagating
# contextvars (it propagates them forwards but not backwards)
- async def homepage(request):
+ async def homepage(request: Request) -> PlainTextResponse:
assert ctxvar.get() == "set by middleware"
ctxvar.set("set by endpoint")
return PlainTextResponse("Homepage")
@pytest.mark.anyio
-async def test_run_background_tasks_even_if_client_disconnects():
+async def test_run_background_tasks_even_if_client_disconnects() -> None:
# test for https://github.com/encode/starlette/issues/1438
request_body_sent = False
response_complete = anyio.Event()
background_task_run = anyio.Event()
- async def sleep_and_set():
+ async def sleep_and_set() -> None:
# small delay to give BaseHTTPMiddleware a chance to cancel us
# this is required to make the test fail prior to fixing the issue
# so do not be surprised if you remove it and the test still passes
await anyio.sleep(0.1)
background_task_run.set()
- async def endpoint_with_background_task(_):
+ async def endpoint_with_background_task(_: Request) -> PlainTextResponse:
return PlainTextResponse(background=BackgroundTask(sleep_and_set))
- async def passthrough(request, call_next):
+ async def passthrough(
+ request: Request,
+ call_next: RequestResponseEndpoint,
+ ) -> Response:
return await call_next(request)
app = Starlette(
"path": "/",
}
- async def receive():
+ async def receive() -> Message:
nonlocal request_body_sent
if not request_body_sent:
request_body_sent = True
await response_complete.wait()
return {"type": "http.disconnect"}
- async def send(message):
+ async def send(message: Message) -> None:
if message["type"] == "http.response.body":
if not message.get("more_body", False):
response_complete.set()
@pytest.mark.anyio
-async def test_do_not_block_on_background_tasks():
+async def test_do_not_block_on_background_tasks() -> None:
request_body_sent = False
response_complete = anyio.Event()
events: List[Union[str, Message]] = []
- async def sleep_and_set():
+ async def sleep_and_set() -> None:
events.append("Background task started")
await anyio.sleep(0.1)
events.append("Background task finished")
- async def endpoint_with_background_task(_):
+ async def endpoint_with_background_task(_: Request) -> PlainTextResponse:
return PlainTextResponse(
content="Hello", background=BackgroundTask(sleep_and_set)
)
async def passthrough(
- request: Request, call_next: Callable[[Request], Awaitable[Response]]
+ request: Request, call_next: RequestResponseEndpoint
) -> Response:
return await call_next(request)
await response_complete.wait()
return {"type": "http.disconnect"}
- async def send(message: Message):
+ async def send(message: Message) -> None:
if message["type"] == "http.response.body":
events.append(message)
if not message.get("more_body", False):
@pytest.mark.anyio
-async def test_run_context_manager_exit_even_if_client_disconnects():
+async def test_run_context_manager_exit_even_if_client_disconnects() -> None:
# test for https://github.com/encode/starlette/issues/1678#issuecomment-1172916042
request_body_sent = False
response_complete = anyio.Event()
context_manager_exited = anyio.Event()
- async def sleep_and_set():
+ async def sleep_and_set() -> None:
# small delay to give BaseHTTPMiddleware a chance to cancel us
# this is required to make the test fail prior to fixing the issue
# so do not be surprised if you remove it and the test still passes
context_manager_exited.set()
class ContextManagerMiddleware:
- def __init__(self, app):
+ def __init__(self, app: ASGIApp):
self.app = app
- async def __call__(self, scope: Scope, receive: Receive, send: Send):
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
async with AsyncExitStack() as stack:
stack.push_async_callback(sleep_and_set)
await self.app(scope, receive, send)
- async def simple_endpoint(_):
+ async def simple_endpoint(_: Request) -> PlainTextResponse:
return PlainTextResponse(background=BackgroundTask(sleep_and_set))
- async def passthrough(request, call_next):
+ async def passthrough(
+ request: Request,
+ call_next: RequestResponseEndpoint,
+ ) -> Response:
return await call_next(request)
app = Starlette(
"path": "/",
}
- async def receive():
+ async def receive() -> Message:
nonlocal request_body_sent
if not request_body_sent:
request_body_sent = True
await response_complete.wait()
return {"type": "http.disconnect"}
- async def send(message):
+ async def send(message: Message) -> None:
if message["type"] == "http.response.body":
if not message.get("more_body", False):
response_complete.set()
assert context_manager_exited.is_set()
-def test_app_receives_http_disconnect_while_sending_if_discarded(test_client_factory):
+def test_app_receives_http_disconnect_while_sending_if_discarded(
+ test_client_factory: TestClientFactory,
+) -> None:
class DiscardingMiddleware(BaseHTTPMiddleware):
- async def dispatch(self, request, call_next):
+ async def dispatch(
+ self,
+ request: Request,
+ call_next: Any,
+ ) -> PlainTextResponse:
# As a matter of ordering, this test targets the case where the downstream
# app response is discarded while it is sending a response body.
# We need to wait for the downstream app to begin sending a response body
return PlainTextResponse("Custom")
- async def downstream_app(scope, receive, send):
+ async def downstream_app(
+ scope: Scope,
+ receive: Receive,
+ send: Send,
+ ) -> None:
await send(
{
"type": "http.response.start",
)
async with anyio.create_task_group() as task_group:
- async def cancel_on_disconnect(*, task_status=anyio.TASK_STATUS_IGNORED):
+ async def cancel_on_disconnect(
+ *,
+ task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED,
+ ) -> None:
task_status.started()
while True:
message = await receive()
assert response.text == "Custom"
-def test_app_receives_http_disconnect_after_sending_if_discarded(test_client_factory):
+def test_app_receives_http_disconnect_after_sending_if_discarded(
+ test_client_factory: TestClientFactory,
+) -> None:
class DiscardingMiddleware(BaseHTTPMiddleware):
- async def dispatch(self, request, call_next):
+ async def dispatch(
+ self,
+ request: Request,
+ call_next: RequestResponseEndpoint,
+ ) -> PlainTextResponse:
await call_next(request)
return PlainTextResponse("Custom")
- async def downstream_app(scope, receive, send):
+ async def downstream_app(
+ scope: Scope,
+ receive: Receive,
+ send: Send,
+ ) -> None:
await send(
{
"type": "http.response.start",
def test_read_request_stream_in_app_after_middleware_calls_stream(
- test_client_factory: Callable[[ASGIApp], TestClient],
+ test_client_factory: TestClientFactory,
) -> None:
- async def homepage(request: Request):
+ async def homepage(request: Request) -> PlainTextResponse:
expected = [b""]
async for chunk in request.stream():
assert chunk == expected.pop(0)
return PlainTextResponse("Homepage")
class ConsumingMiddleware(BaseHTTPMiddleware):
- async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
+ async def dispatch(
+ self,
+ request: Request,
+ call_next: RequestResponseEndpoint,
+ ) -> Response:
expected = [b"a", b""]
async for chunk in request.stream():
assert chunk == expected.pop(0)
def test_read_request_stream_in_app_after_middleware_calls_body(
- test_client_factory: Callable[[ASGIApp], TestClient],
+ test_client_factory: TestClientFactory,
) -> None:
- async def homepage(request: Request):
+ async def homepage(request: Request) -> PlainTextResponse:
expected = [b"a", b""]
async for chunk in request.stream():
assert chunk == expected.pop(0)
return PlainTextResponse("Homepage")
class ConsumingMiddleware(BaseHTTPMiddleware):
- async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
+ async def dispatch(
+ self,
+ request: Request,
+ call_next: RequestResponseEndpoint,
+ ) -> Response:
assert await request.body() == b"a"
return await call_next(request)
def test_read_request_body_in_app_after_middleware_calls_stream(
- test_client_factory: Callable[[ASGIApp], TestClient],
+ test_client_factory: TestClientFactory,
) -> None:
- async def homepage(request: Request):
+ async def homepage(request: Request) -> PlainTextResponse:
assert await request.body() == b""
return PlainTextResponse("Homepage")
class ConsumingMiddleware(BaseHTTPMiddleware):
- async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
+ async def dispatch(
+ self,
+ request: Request,
+ call_next: RequestResponseEndpoint,
+ ) -> Response:
expected = [b"a", b""]
async for chunk in request.stream():
assert chunk == expected.pop(0)
def test_read_request_body_in_app_after_middleware_calls_body(
- test_client_factory: Callable[[ASGIApp], TestClient],
+ test_client_factory: TestClientFactory,
) -> None:
- async def homepage(request: Request):
+ async def homepage(request: Request) -> PlainTextResponse:
assert await request.body() == b"a"
return PlainTextResponse("Homepage")
class ConsumingMiddleware(BaseHTTPMiddleware):
- async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
+ async def dispatch(
+ self,
+ request: Request,
+ call_next: RequestResponseEndpoint,
+ ) -> Response:
assert await request.body() == b"a"
return await call_next(request)
def test_read_request_stream_in_dispatch_after_app_calls_stream(
- test_client_factory: Callable[[ASGIApp], TestClient],
+ test_client_factory: TestClientFactory,
) -> None:
- async def homepage(request: Request):
+ async def homepage(request: Request) -> PlainTextResponse:
expected = [b"a", b""]
async for chunk in request.stream():
assert chunk == expected.pop(0)
return PlainTextResponse("Homepage")
class ConsumingMiddleware(BaseHTTPMiddleware):
- async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
+ async def dispatch(
+ self,
+ request: Request,
+ call_next: RequestResponseEndpoint,
+ ) -> Response:
resp = await call_next(request)
with pytest.raises(RuntimeError, match="Stream consumed"):
async for _ in request.stream():
def test_read_request_stream_in_dispatch_after_app_calls_body(
- test_client_factory: Callable[[ASGIApp], TestClient],
+ test_client_factory: TestClientFactory,
) -> None:
- async def homepage(request: Request):
+ async def homepage(request: Request) -> PlainTextResponse:
assert await request.body() == b"a"
return PlainTextResponse("Homepage")
class ConsumingMiddleware(BaseHTTPMiddleware):
- async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
+ async def dispatch(
+ self,
+ request: Request,
+ call_next: RequestResponseEndpoint,
+ ) -> Response:
resp = await call_next(request)
with pytest.raises(RuntimeError, match="Stream consumed"):
async for _ in request.stream():
await Response()(scope, receive, send)
class ConsumingMiddleware(BaseHTTPMiddleware):
- async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
+ async def dispatch(
+ self,
+ request: Request,
+ call_next: RequestResponseEndpoint,
+ ) -> Response:
expected = b"1"
response: Union[Response, None] = None
async for chunk in request.stream():
def test_read_request_stream_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next( # noqa: E501
- test_client_factory: Callable[[ASGIApp], TestClient],
+ test_client_factory: TestClientFactory,
) -> None:
- async def homepage(request: Request):
+ async def homepage(request: Request) -> PlainTextResponse:
assert await request.body() == b"a"
return PlainTextResponse("Homepage")
class ConsumingMiddleware(BaseHTTPMiddleware):
- async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
+ async def dispatch(
+ self,
+ request: Request,
+ call_next: RequestResponseEndpoint,
+ ) -> Response:
assert (
await request.body() == b"a"
) # this buffers the request body in memory
def test_read_request_body_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next( # noqa: E501
- test_client_factory: Callable[[ASGIApp], TestClient],
+ test_client_factory: TestClientFactory,
) -> None:
- async def homepage(request: Request):
+ async def homepage(request: Request) -> PlainTextResponse:
assert await request.body() == b"a"
return PlainTextResponse("Homepage")
class ConsumingMiddleware(BaseHTTPMiddleware):
- async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
+ async def dispatch(
+ self,
+ request: Request,
+ call_next: RequestResponseEndpoint,
+ ) -> Response:
assert (
await request.body() == b"a"
) # this buffers the request body in memory
await Response()(scope, receive, send)
class ConsumingMiddleware(BaseHTTPMiddleware):
- async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
+ async def dispatch(
+ self,
+ request: Request,
+ call_next: RequestResponseEndpoint,
+ ) -> Response:
response = await call_next(request)
disconnected = await request.is_disconnected()
assert disconnected is True
yield {"type": "http.disconnect"}
raise AssertionError("Should not be called, would hang") # pragma: no cover
- async def send(msg: Message):
+ async def send(msg: Message) -> None:
if msg["type"] == "http.response.start":
assert msg["status"] == 200
await Response()(scope, receive, send)
class ConsumingMiddleware(BaseHTTPMiddleware):
- async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
+ async def dispatch(
+ self,
+ request: Request,
+ call_next: RequestResponseEndpoint,
+ ) -> Response:
await request.body()
disconnected = await request.is_disconnected()
assert disconnected is True
yield {"type": "http.disconnect"}
raise AssertionError("Should not be called, would hang") # pragma: no cover
- async def send(msg: Message):
+ async def send(msg: Message) -> None:
if msg["type"] == "http.response.start":
assert msg["status"] == 200
def test_downstream_middleware_modifies_receive(
- test_client_factory: Callable[[ASGIApp], TestClient],
+ test_client_factory: TestClientFactory,
) -> None:
"""If a downstream middleware modifies receive() the final ASGI app
should see the modified version.
await Response()(scope, receive, send)
class ConsumingMiddleware(BaseHTTPMiddleware):
- async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
+ async def dispatch(
+ self,
+ request: Request,
+ call_next: RequestResponseEndpoint,
+ ) -> Response:
body = await request.body()
assert body == b"foo "
return await call_next(request)
assert resp.status_code == 200
-CallNext = Callable[[Request], Awaitable[Response]]
-
-
def test_pr_1519_comment_1236166180_example() -> None:
"""
https://github.com/encode/starlette/pull/1519#issuecomment-1236166180
bodies: List[bytes] = []
class LogRequestBodySize(BaseHTTPMiddleware):
- async def dispatch(self, request: Request, call_next: CallNext) -> Response:
+ async def dispatch(
+ self,
+ request: Request,
+ call_next: RequestResponseEndpoint,
+ ) -> Response:
print(len(await request.body()))
return await call_next(request)