From: florimondmanca Date: Wed, 15 Jun 2022 20:37:50 +0000 (+0200) Subject: Address X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=9e48c1f18a3b5a8731406d8add445dba1c5a1335;p=thirdparty%2Fstarlette.git Address --- diff --git a/starlette/_compat.py b/starlette/_compat.py index 43bba40a..76043174 100644 --- a/starlette/_compat.py +++ b/starlette/_compat.py @@ -1,5 +1,6 @@ import hashlib -from typing import Any +import sys +from typing import Any, AsyncContextManager __all__ = [ "md5_hexdigest", @@ -35,11 +36,11 @@ except TypeError: # pragma: no cover return hashlib.md5(data).hexdigest() -try: - from contextlib import aclosing # type: ignore[attr-defined] -except ImportError: # Python < 3.10 # pragma: no cover +if sys.version_info >= (3, 10): # pragma: no cover + from contextlib import aclosing +else: # pragma: no cover - class aclosing: # type: ignore + class aclosing(AsyncContextManager): def __init__(self, thing: Any) -> None: self.thing = thing diff --git a/starlette/middleware/http.py b/starlette/middleware/http.py index 16b7b35c..aa167526 100644 --- a/starlette/middleware/http.py +++ b/starlette/middleware/http.py @@ -6,20 +6,22 @@ from ..requests import HTTPConnection from ..responses import Response from ..types import ASGIApp, Message, Receive, Scope, Send +# This type hint not exposed, as it exists mostly for our own documentation purposes. +# End users should use one of these type hints explicitly when overriding '.dispatch()'. _DispatchFlow = Union[ # Default case: - # response = yield + # response = yield AsyncGenerator[None, Response], # Early response and/or error handling: - # if condition: - # yield Response(...) - # return - # try: - # response = yield None - # except Exception: - # yield Response(...) - # else: - # ... + # if condition: + # yield Response(...) + # return + # try: + # response = yield None + # except Exception: + # yield Response(...) + # else: + # ... AsyncGenerator[Optional[Response], Response], ] @@ -30,14 +32,15 @@ class HTTPMiddleware: app: ASGIApp, dispatch: Optional[Callable[[HTTPConnection], _DispatchFlow]] = None, ) -> None: - if dispatch is None: - dispatch = self.dispatch - self.app = app - self._dispatch_func = dispatch + self._dispatch_func = self.dispatch if dispatch is None else dispatch - def dispatch(self, conn: HTTPConnection) -> _DispatchFlow: - raise NotImplementedError # pragma: no cover + def dispatch(self, __conn: HTTPConnection) -> _DispatchFlow: + raise NotImplementedError( + "No dispatch implementation was given. " + "Either pass 'dispatch=...' to HTTPMiddleware, " + "or subclass HTTPMiddleware and override the 'dispatch()' method." + ) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] != "http": @@ -52,14 +55,23 @@ class HTTPMiddleware: maybe_early_response = await flow.__anext__() if maybe_early_response is not None: + try: + await flow.__anext__() + except StopAsyncIteration: + pass + else: + raise RuntimeError("dispatch() should yield exactly once") + await maybe_early_response(scope, receive, send) return - response_started: set = set() + response_started = False async def wrapped_send(message: Message) -> None: + nonlocal response_started + if message["type"] == "http.response.start": - response_started.add(True) + response_started = True response = Response(status_code=message["status"]) response.raw_headers.clear() @@ -73,6 +85,7 @@ class HTTPMiddleware: headers = MutableHeaders(raw=message["headers"]) headers.update(response.headers) + message["headers"] = headers.raw await send(message) @@ -98,6 +111,3 @@ class HTTPMiddleware: await response(scope, receive, send) return - - if not response_started: - raise RuntimeError("No response returned.") diff --git a/tests/middleware/test_http.py b/tests/middleware/test_http.py index 5f83d313..cbe1c3ac 100644 --- a/tests/middleware/test_http.py +++ b/tests/middleware/test_http.py @@ -1,70 +1,61 @@ -import contextvars -from typing import AsyncGenerator, Optional +from typing import AsyncGenerator, Callable, Iterator, Optional import pytest from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.http import HTTPMiddleware -from starlette.requests import HTTPConnection +from starlette.requests import HTTPConnection, Request from starlette.responses import PlainTextResponse, Response, StreamingResponse from starlette.routing import Route, WebSocketRoute -from starlette.types import ASGIApp, Receive, Scope, Send +from starlette.testclient import TestClient +from starlette.types import ASGIApp +from starlette.websockets import WebSocket -class CustomMiddleware(HTTPMiddleware): - async def dispatch(self, conn: HTTPConnection) -> AsyncGenerator[None, Response]: - response = yield - response.headers["Custom-Header"] = "Example" - - -def homepage(request): +async def homepage(request: Request) -> Response: return PlainTextResponse("Homepage") -def exc(request): +async def exc(request: Request) -> Response: raise Exception("Exc") -def exc_stream(request): +async def exc_stream(request: Request) -> Response: return StreamingResponse(_generate_faulty_stream()) -def _generate_faulty_stream(): +def _generate_faulty_stream() -> Iterator[bytes]: yield b"Ok" raise Exception("Faulty Stream") -class NoResponse: - def __init__(self, scope, receive, send): - pass - - def __await__(self): - return self.dispatch().__await__() - - async def dispatch(self): - pass - - -async def websocket_endpoint(session): +async def websocket_endpoint(session: WebSocket) -> None: await session.accept() await session.send_text("Hello, world!") await session.close() +class CustomMiddleware(HTTPMiddleware): + async def dispatch(self, conn: HTTPConnection) -> AsyncGenerator[None, Response]: + response = yield + response.headers["Custom-Header"] = "Example" + + app = Starlette( routes=[ Route("/", endpoint=homepage), Route("/exc", endpoint=exc), Route("/exc-stream", endpoint=exc_stream), - Route("/no-response", endpoint=NoResponse), WebSocketRoute("/ws", endpoint=websocket_endpoint), ], middleware=[Middleware(CustomMiddleware)], ) -def test_custom_middleware(test_client_factory): +def test_custom_middleware( + test_client_factory: Callable[[ASGIApp], TestClient] +) -> None: client = test_client_factory(app) response = client.get("/") assert response.headers["Custom-Header"] == "Example" @@ -77,49 +68,39 @@ def test_custom_middleware(test_client_factory): response = client.get("/exc-stream") assert str(ctx.value) == "Faulty Stream" - with pytest.raises(RuntimeError): - response = client.get("/no-response") - with client.websocket_connect("/ws") as session: text = session.receive_text() assert text == "Hello, world!" -def test_state_data_across_multiple_middlewares(test_client_factory): +def test_state_data_across_multiple_middlewares( + test_client_factory: Callable[[ASGIApp], TestClient] +) -> None: + async def homepage(request: Request) -> Response: + return PlainTextResponse("OK") + expected_value1 = "foo" expected_value2 = "bar" - class aMiddleware(HTTPMiddleware): - async def dispatch( - self, conn: HTTPConnection - ) -> AsyncGenerator[None, Response]: - conn.state.foo = expected_value1 - yield - - class bMiddleware(HTTPMiddleware): - async def dispatch( - self, conn: HTTPConnection - ) -> AsyncGenerator[None, Response]: - conn.state.bar = expected_value2 - response = yield - response.headers["X-State-Foo"] = conn.state.foo + async def middleware_a(conn: HTTPConnection) -> AsyncGenerator[None, Response]: + conn.state.foo = expected_value1 + yield - class cMiddleware(HTTPMiddleware): - async def dispatch( - self, conn: HTTPConnection - ) -> AsyncGenerator[None, Response]: - response = yield - response.headers["X-State-Bar"] = conn.state.bar + async def middleware_b(conn: HTTPConnection) -> AsyncGenerator[None, Response]: + conn.state.bar = expected_value2 + response = yield + response.headers["X-State-Foo"] = conn.state.foo - def homepage(request): - return PlainTextResponse("OK") + async def middleware_c(conn: HTTPConnection) -> AsyncGenerator[None, Response]: + response = yield + response.headers["X-State-Bar"] = conn.state.bar app = Starlette( routes=[Route("/", homepage)], middleware=[ - Middleware(aMiddleware), - Middleware(bMiddleware), - Middleware(cMiddleware), + Middleware(HTTPMiddleware, dispatch=middleware_a), + Middleware(HTTPMiddleware, dispatch=middleware_b), + Middleware(HTTPMiddleware, dispatch=middleware_c), ], ) @@ -130,31 +111,23 @@ def test_state_data_across_multiple_middlewares(test_client_factory): assert response.headers["X-State-Bar"] == expected_value2 -def test_dispatch_argument(test_client_factory): - def homepage(request): - return PlainTextResponse("Homepage") - - async def dispatch(conn: HTTPConnection) -> AsyncGenerator[None, Response]: - response = yield - response.headers["Custom-Header"] = "Example" +def test_too_many_yields(test_client_factory: Callable[[ASGIApp], TestClient]) -> None: + class CustomMiddleware(HTTPMiddleware): + async def dispatch( + self, conn: HTTPConnection + ) -> AsyncGenerator[None, Response]: + _ = yield + yield - app = Starlette( - routes=[Route("/", homepage)], - middleware=[Middleware(HTTPMiddleware, dispatch=dispatch)], - ) + app = Starlette(middleware=[Middleware(CustomMiddleware)]) client = test_client_factory(app) - response = client.get("/") - assert response.headers["Custom-Header"] == "Example" - - -def test_middleware_repr(): - middleware = Middleware(CustomMiddleware) - assert repr(middleware) == "Middleware(CustomMiddleware)" + with pytest.raises(RuntimeError, match="should yield exactly once"): + client.get("/") -def test_early_response(test_client_factory): - async def index(request): +def test_early_response(test_client_factory: Callable[[ASGIApp], TestClient]) -> None: + async def index(request: Request) -> Response: return PlainTextResponse("Hello, world!") class CustomMiddleware(HTTPMiddleware): @@ -179,13 +152,15 @@ def test_early_response(test_client_factory): assert response.status_code == 401 -def test_too_many_yields(test_client_factory) -> None: +def test_early_response_too_many_yields( + test_client_factory: Callable[[ASGIApp], TestClient] +) -> None: class CustomMiddleware(HTTPMiddleware): async def dispatch( self, conn: HTTPConnection - ) -> AsyncGenerator[None, Response]: - _ = yield - yield + ) -> AsyncGenerator[Optional[Response], Response]: + yield Response() + yield None app = Starlette(middleware=[Middleware(CustomMiddleware)]) @@ -194,11 +169,11 @@ def test_too_many_yields(test_client_factory) -> None: client.get("/") -def test_error_response(test_client_factory) -> None: +def test_error_response(test_client_factory: Callable[[ASGIApp], TestClient]) -> None: class Failed(Exception): pass - async def failure(request): + async def failure(request: Request) -> Response: raise Failed() class CustomMiddleware(HTTPMiddleware): @@ -221,11 +196,13 @@ def test_error_response(test_client_factory) -> None: assert response.status_code == 500 -def test_no_error_response(test_client_factory) -> None: +def test_error_handling_must_send_response( + test_client_factory: Callable[[ASGIApp], TestClient] +) -> None: class Failed(Exception): pass - async def index(request): + async def index(request: Request) -> Response: raise Failed() class CustomMiddleware(HTTPMiddleware): @@ -235,7 +212,7 @@ def test_no_error_response(test_client_factory) -> None: try: yield except Failed: - pass + pass # `yield ` expected app = Starlette( routes=[Route("/", index)], @@ -247,43 +224,10 @@ def test_no_error_response(test_client_factory) -> None: client.get("/") -ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar") - - -class PureASGICustomMiddleware: - def __init__(self, app: ASGIApp) -> None: - self.app = app - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - ctxvar.set("set by middleware") - await self.app(scope, receive, send) - assert ctxvar.get() == "set by endpoint" - - -class HTTPCustomMiddleware(HTTPMiddleware): - async def dispatch(self, conn: HTTPConnection) -> AsyncGenerator[None, Response]: - ctxvar.set("set by middleware") - yield - assert ctxvar.get() == "set by endpoint" - - -@pytest.mark.parametrize( - "middleware_cls", - [ - PureASGICustomMiddleware, - HTTPCustomMiddleware, - ], -) -def test_contextvars(test_client_factory, middleware_cls: type): - async def homepage(request): - assert ctxvar.get() == "set by middleware" - ctxvar.set("set by endpoint") - return PlainTextResponse("Homepage") - - app = Starlette( - middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)] - ) - +def test_no_dispatch_given( + test_client_factory: Callable[[ASGIApp], TestClient] +) -> None: + app = Starlette(middleware=[Middleware(HTTPMiddleware)]) client = test_client_factory(app) - response = client.get("/") - assert response.status_code == 200, response.content + with pytest.raises(NotImplementedError, match="No dispatch implementation"): + client.get("/")