]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Address
authorflorimondmanca <florimond.manca@protonmail.com>
Wed, 15 Jun 2022 20:37:50 +0000 (22:37 +0200)
committerflorimondmanca <florimond.manca@protonmail.com>
Wed, 15 Jun 2022 20:39:01 +0000 (22:39 +0200)
starlette/_compat.py
starlette/middleware/http.py
tests/middleware/test_http.py

index 43bba40ac41230e4a83ff5f527721eba37625931..760431745dab7b5f384dbffcc367985b3a4da40f 100644 (file)
@@ -1,5 +1,6 @@
 import hashlib
-from typing import Any
+import sys
+from typing import Any, AsyncContextManager
 
 __all__ = [
     "md5_hexdigest",
@@ -35,11 +36,11 @@ except TypeError:  # pragma: no cover
         return hashlib.md5(data).hexdigest()
 
 
-try:
-    from contextlib import aclosing  # type: ignore[attr-defined]
-except ImportError:  # Python < 3.10  # pragma: no cover
+if sys.version_info >= (3, 10):  # pragma: no cover
+    from contextlib import aclosing
+else:  # pragma: no cover
 
-    class aclosing:  # type: ignore
+    class aclosing(AsyncContextManager):
         def __init__(self, thing: Any) -> None:
             self.thing = thing
 
index 16b7b35cb32653d3d000dcb9db26904f1a6ca67d..aa16752639b4d1eb1207b2b07c684623568c6540 100644 (file)
@@ -6,20 +6,22 @@ from ..requests import HTTPConnection
 from ..responses import Response
 from ..types import ASGIApp, Message, Receive, Scope, Send
 
+# This type hint not exposed, as it exists mostly for our own documentation purposes.
+# End users should use one of these type hints explicitly when overriding '.dispatch()'.
 _DispatchFlow = Union[
     # Default case:
-    # response = yield
+    #   response = yield
     AsyncGenerator[None, Response],
     # Early response and/or error handling:
-    # if condition:
-    #     yield Response(...)
-    #     return
-    # try:
-    #     response = yield None
-    # except Exception:
-    #     yield Response(...)
-    # else:
-    #    ...
+    #   if condition:
+    #       yield Response(...)
+    #       return
+    #   try:
+    #       response = yield None
+    #   except Exception:
+    #       yield Response(...)
+    #   else:
+    #       ...
     AsyncGenerator[Optional[Response], Response],
 ]
 
@@ -30,14 +32,15 @@ class HTTPMiddleware:
         app: ASGIApp,
         dispatch: Optional[Callable[[HTTPConnection], _DispatchFlow]] = None,
     ) -> None:
-        if dispatch is None:
-            dispatch = self.dispatch
-
         self.app = app
-        self._dispatch_func = dispatch
+        self._dispatch_func = self.dispatch if dispatch is None else dispatch
 
-    def dispatch(self, conn: HTTPConnection) -> _DispatchFlow:
-        raise NotImplementedError  # pragma: no cover
+    def dispatch(self, __conn: HTTPConnection) -> _DispatchFlow:
+        raise NotImplementedError(
+            "No dispatch implementation was given. "
+            "Either pass 'dispatch=...' to HTTPMiddleware, "
+            "or subclass HTTPMiddleware and override the 'dispatch()' method."
+        )
 
     async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
         if scope["type"] != "http":
@@ -52,14 +55,23 @@ class HTTPMiddleware:
             maybe_early_response = await flow.__anext__()
 
             if maybe_early_response is not None:
+                try:
+                    await flow.__anext__()
+                except StopAsyncIteration:
+                    pass
+                else:
+                    raise RuntimeError("dispatch() should yield exactly once")
+
                 await maybe_early_response(scope, receive, send)
                 return
 
-            response_started: set = set()
+            response_started = False
 
             async def wrapped_send(message: Message) -> None:
+                nonlocal response_started
+
                 if message["type"] == "http.response.start":
-                    response_started.add(True)
+                    response_started = True
 
                     response = Response(status_code=message["status"])
                     response.raw_headers.clear()
@@ -73,6 +85,7 @@ class HTTPMiddleware:
 
                     headers = MutableHeaders(raw=message["headers"])
                     headers.update(response.headers)
+                    message["headers"] = headers.raw
 
                 await send(message)
 
@@ -98,6 +111,3 @@ class HTTPMiddleware:
 
                 await response(scope, receive, send)
                 return
-
-            if not response_started:
-                raise RuntimeError("No response returned.")
index 5f83d313feb4a9469f5399b386376d1b05acd5f2..cbe1c3ac73b0fef1a8d1103b7fd9fbfa6027367c 100644 (file)
@@ -1,70 +1,61 @@
-import contextvars
-from typing import AsyncGenerator, Optional
+from typing import AsyncGenerator, Callable, Iterator, Optional
 
 import pytest
 
 from starlette.applications import Starlette
 from starlette.middleware import Middleware
 from starlette.middleware.http import HTTPMiddleware
-from starlette.requests import HTTPConnection
+from starlette.requests import HTTPConnection, Request
 from starlette.responses import PlainTextResponse, Response, StreamingResponse
 from starlette.routing import Route, WebSocketRoute
-from starlette.types import ASGIApp, Receive, Scope, Send
+from starlette.testclient import TestClient
+from starlette.types import ASGIApp
+from starlette.websockets import WebSocket
 
 
-class CustomMiddleware(HTTPMiddleware):
-    async def dispatch(self, conn: HTTPConnection) -> AsyncGenerator[None, Response]:
-        response = yield
-        response.headers["Custom-Header"] = "Example"
-
-
-def homepage(request):
+async def homepage(request: Request) -> Response:
     return PlainTextResponse("Homepage")
 
 
-def exc(request):
+async def exc(request: Request) -> Response:
     raise Exception("Exc")
 
 
-def exc_stream(request):
+async def exc_stream(request: Request) -> Response:
     return StreamingResponse(_generate_faulty_stream())
 
 
-def _generate_faulty_stream():
+def _generate_faulty_stream() -> Iterator[bytes]:
     yield b"Ok"
     raise Exception("Faulty Stream")
 
 
-class NoResponse:
-    def __init__(self, scope, receive, send):
-        pass
-
-    def __await__(self):
-        return self.dispatch().__await__()
-
-    async def dispatch(self):
-        pass
-
-
-async def websocket_endpoint(session):
+async def websocket_endpoint(session: WebSocket) -> None:
     await session.accept()
     await session.send_text("Hello, world!")
     await session.close()
 
 
+class CustomMiddleware(HTTPMiddleware):
+    async def dispatch(self, conn: HTTPConnection) -> AsyncGenerator[None, Response]:
+        response = yield
+        response.headers["Custom-Header"] = "Example"
+
+
 app = Starlette(
     routes=[
         Route("/", endpoint=homepage),
         Route("/exc", endpoint=exc),
         Route("/exc-stream", endpoint=exc_stream),
-        Route("/no-response", endpoint=NoResponse),
         WebSocketRoute("/ws", endpoint=websocket_endpoint),
     ],
     middleware=[Middleware(CustomMiddleware)],
 )
 
 
-def test_custom_middleware(test_client_factory):
+def test_custom_middleware(
+    test_client_factory: Callable[[ASGIApp], TestClient]
+) -> None:
     client = test_client_factory(app)
     response = client.get("/")
     assert response.headers["Custom-Header"] == "Example"
@@ -77,49 +68,39 @@ def test_custom_middleware(test_client_factory):
         response = client.get("/exc-stream")
     assert str(ctx.value) == "Faulty Stream"
 
-    with pytest.raises(RuntimeError):
-        response = client.get("/no-response")
-
     with client.websocket_connect("/ws") as session:
         text = session.receive_text()
         assert text == "Hello, world!"
 
 
-def test_state_data_across_multiple_middlewares(test_client_factory):
+def test_state_data_across_multiple_middlewares(
+    test_client_factory: Callable[[ASGIApp], TestClient]
+) -> None:
+    async def homepage(request: Request) -> Response:
+        return PlainTextResponse("OK")
+
     expected_value1 = "foo"
     expected_value2 = "bar"
 
-    class aMiddleware(HTTPMiddleware):
-        async def dispatch(
-            self, conn: HTTPConnection
-        ) -> AsyncGenerator[None, Response]:
-            conn.state.foo = expected_value1
-            yield
-
-    class bMiddleware(HTTPMiddleware):
-        async def dispatch(
-            self, conn: HTTPConnection
-        ) -> AsyncGenerator[None, Response]:
-            conn.state.bar = expected_value2
-            response = yield
-            response.headers["X-State-Foo"] = conn.state.foo
+    async def middleware_a(conn: HTTPConnection) -> AsyncGenerator[None, Response]:
+        conn.state.foo = expected_value1
+        yield
 
-    class cMiddleware(HTTPMiddleware):
-        async def dispatch(
-            self, conn: HTTPConnection
-        ) -> AsyncGenerator[None, Response]:
-            response = yield
-            response.headers["X-State-Bar"] = conn.state.bar
+    async def middleware_b(conn: HTTPConnection) -> AsyncGenerator[None, Response]:
+        conn.state.bar = expected_value2
+        response = yield
+        response.headers["X-State-Foo"] = conn.state.foo
 
-    def homepage(request):
-        return PlainTextResponse("OK")
+    async def middleware_c(conn: HTTPConnection) -> AsyncGenerator[None, Response]:
+        response = yield
+        response.headers["X-State-Bar"] = conn.state.bar
 
     app = Starlette(
         routes=[Route("/", homepage)],
         middleware=[
-            Middleware(aMiddleware),
-            Middleware(bMiddleware),
-            Middleware(cMiddleware),
+            Middleware(HTTPMiddleware, dispatch=middleware_a),
+            Middleware(HTTPMiddleware, dispatch=middleware_b),
+            Middleware(HTTPMiddleware, dispatch=middleware_c),
         ],
     )
 
@@ -130,31 +111,23 @@ def test_state_data_across_multiple_middlewares(test_client_factory):
     assert response.headers["X-State-Bar"] == expected_value2
 
 
-def test_dispatch_argument(test_client_factory):
-    def homepage(request):
-        return PlainTextResponse("Homepage")
-
-    async def dispatch(conn: HTTPConnection) -> AsyncGenerator[None, Response]:
-        response = yield
-        response.headers["Custom-Header"] = "Example"
+def test_too_many_yields(test_client_factory: Callable[[ASGIApp], TestClient]) -> None:
+    class CustomMiddleware(HTTPMiddleware):
+        async def dispatch(
+            self, conn: HTTPConnection
+        ) -> AsyncGenerator[None, Response]:
+            _ = yield
+            yield
 
-    app = Starlette(
-        routes=[Route("/", homepage)],
-        middleware=[Middleware(HTTPMiddleware, dispatch=dispatch)],
-    )
+    app = Starlette(middleware=[Middleware(CustomMiddleware)])
 
     client = test_client_factory(app)
-    response = client.get("/")
-    assert response.headers["Custom-Header"] == "Example"
-
-
-def test_middleware_repr():
-    middleware = Middleware(CustomMiddleware)
-    assert repr(middleware) == "Middleware(CustomMiddleware)"
+    with pytest.raises(RuntimeError, match="should yield exactly once"):
+        client.get("/")
 
 
-def test_early_response(test_client_factory):
-    async def index(request):
+def test_early_response(test_client_factory: Callable[[ASGIApp], TestClient]) -> None:
+    async def index(request: Request) -> Response:
         return PlainTextResponse("Hello, world!")
 
     class CustomMiddleware(HTTPMiddleware):
@@ -179,13 +152,15 @@ def test_early_response(test_client_factory):
     assert response.status_code == 401
 
 
-def test_too_many_yields(test_client_factory) -> None:
+def test_early_response_too_many_yields(
+    test_client_factory: Callable[[ASGIApp], TestClient]
+) -> None:
     class CustomMiddleware(HTTPMiddleware):
         async def dispatch(
             self, conn: HTTPConnection
-        ) -> AsyncGenerator[None, Response]:
-            _ = yield
-            yield
+        ) -> AsyncGenerator[Optional[Response], Response]:
+            yield Response()
+            yield None
 
     app = Starlette(middleware=[Middleware(CustomMiddleware)])
 
@@ -194,11 +169,11 @@ def test_too_many_yields(test_client_factory) -> None:
         client.get("/")
 
 
-def test_error_response(test_client_factory) -> None:
+def test_error_response(test_client_factory: Callable[[ASGIApp], TestClient]) -> None:
     class Failed(Exception):
         pass
 
-    async def failure(request):
+    async def failure(request: Request) -> Response:
         raise Failed()
 
     class CustomMiddleware(HTTPMiddleware):
@@ -221,11 +196,13 @@ def test_error_response(test_client_factory) -> None:
     assert response.status_code == 500
 
 
-def test_no_error_response(test_client_factory) -> None:
+def test_error_handling_must_send_response(
+    test_client_factory: Callable[[ASGIApp], TestClient]
+) -> None:
     class Failed(Exception):
         pass
 
-    async def index(request):
+    async def index(request: Request) -> Response:
         raise Failed()
 
     class CustomMiddleware(HTTPMiddleware):
@@ -235,7 +212,7 @@ def test_no_error_response(test_client_factory) -> None:
             try:
                 yield
             except Failed:
-                pass
+                pass  # `yield <response>` expected
 
     app = Starlette(
         routes=[Route("/", index)],
@@ -247,43 +224,10 @@ def test_no_error_response(test_client_factory) -> None:
         client.get("/")
 
 
-ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar")
-
-
-class PureASGICustomMiddleware:
-    def __init__(self, app: ASGIApp) -> None:
-        self.app = app
-
-    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
-        ctxvar.set("set by middleware")
-        await self.app(scope, receive, send)
-        assert ctxvar.get() == "set by endpoint"
-
-
-class HTTPCustomMiddleware(HTTPMiddleware):
-    async def dispatch(self, conn: HTTPConnection) -> AsyncGenerator[None, Response]:
-        ctxvar.set("set by middleware")
-        yield
-        assert ctxvar.get() == "set by endpoint"
-
-
-@pytest.mark.parametrize(
-    "middleware_cls",
-    [
-        PureASGICustomMiddleware,
-        HTTPCustomMiddleware,
-    ],
-)
-def test_contextvars(test_client_factory, middleware_cls: type):
-    async def homepage(request):
-        assert ctxvar.get() == "set by middleware"
-        ctxvar.set("set by endpoint")
-        return PlainTextResponse("Homepage")
-
-    app = Starlette(
-        middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)]
-    )
-
+def test_no_dispatch_given(
+    test_client_factory: Callable[[ASGIApp], TestClient]
+) -> None:
+    app = Starlette(middleware=[Middleware(HTTPMiddleware)])
     client = test_client_factory(app)
-    response = client.get("/")
-    assert response.status_code == 200, response.content
+    with pytest.raises(NotImplementedError, match="No dispatch implementation"):
+        client.get("/")