methods: typing.Optional[typing.List[str]] = None,
name: typing.Optional[str] = None,
include_in_schema: bool = True,
+ middleware: typing.Optional[typing.Sequence[Middleware]] = None,
) -> None:
assert path.startswith("/"), "Routed paths must start with '/'"
self.path = path
# Endpoint is a class. Treat it as ASGI.
self.app = endpoint
+ if middleware is not None:
+ for cls, options in reversed(middleware):
+ self.app = cls(app=self.app, **options)
+
if methods is None:
self.methods = None
else:
endpoint: typing.Callable[..., typing.Any],
*,
name: typing.Optional[str] = None,
+ middleware: typing.Optional[typing.Sequence[Middleware]] = None,
) -> None:
assert path.startswith("/"), "Routed paths must start with '/'"
self.path = path
# Endpoint is a class. Treat it as ASGI.
self.app = endpoint
+ if middleware is not None:
+ for cls, options in reversed(middleware):
+ self.app = cls(app=self.app, **options)
+
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
def websocket_connect(
self, url: str, subprotocols: typing.Sequence[str] = None, **kwargs: typing.Any
- ) -> typing.Any:
+ ) -> "WebSocketTestSession":
url = urljoin("ws://testserver", url)
headers = kwargs.get("headers", {})
headers.setdefault("connection", "upgrade")
return Response()
+route_with_middleware = Starlette(
+ routes=[
+ Route(
+ "/http",
+ endpoint=assert_middleware_header_route,
+ methods=["GET"],
+ middleware=[Middleware(AddHeadersMiddleware)],
+ ),
+ Route("/home", homepage),
+ ]
+)
+
mounted_routes_with_middleware = Starlette(
routes=[
Mount(
[
mounted_routes_with_middleware,
mounted_app_with_middleware,
+ route_with_middleware,
],
)
-def test_mount_middleware(
+def test_base_route_middleware(
test_client_factory: typing.Callable[..., TestClient],
app: Starlette,
) -> None:
assert "X-Mounted" in resp.headers
+def test_websocket_route_middleware(
+ test_client_factory: typing.Callable[..., TestClient]
+):
+ async def websocket_endpoint(session: WebSocket):
+ await session.accept()
+ await session.send_text("Hello, world!")
+ await session.close()
+
+ class WebsocketMiddleware:
+ def __init__(self, app: ASGIApp) -> None:
+ self.app = app
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ async def modified_send(msg: Message) -> None:
+ if msg["type"] == "websocket.accept":
+ msg["headers"].append((b"X-Test", b"Set by middleware"))
+ await send(msg)
+
+ await self.app(scope, receive, modified_send)
+
+ app = Starlette(
+ routes=[
+ WebSocketRoute(
+ "/ws",
+ endpoint=websocket_endpoint,
+ middleware=[Middleware(WebsocketMiddleware)],
+ )
+ ]
+ )
+
+ client = test_client_factory(app)
+
+ with client.websocket_connect("/ws") as websocket:
+ text = websocket.receive_text()
+ assert text == "Hello, world!"
+ assert websocket.extra_headers == [(b"X-Test", b"Set by middleware")]
+
+
def test_route_repr() -> None:
route = Route("/welcome", endpoint=homepage)
assert (