]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Refactor, lint
authorflorimondmanca <florimond.manca@protonmail.com>
Tue, 14 Jun 2022 21:04:45 +0000 (23:04 +0200)
committerflorimondmanca <florimond.manca@protonmail.com>
Tue, 14 Jun 2022 21:16:17 +0000 (23:16 +0200)
starlette/middleware/http.py
tests/middleware/test_http.py

index 61571d5491dedc9f26899b9c69e2531526cf1831..8284e98917a6a820a5708df7ace6321029dffa52 100644 (file)
@@ -1,68 +1,79 @@
+from contextlib import aclosing
 from functools import partial
-from typing import Generator, Optional
-from ..types import ASGIApp, Scope, Receive, Send, Message
-from ..responses import Response
+from typing import AsyncGenerator, Callable, Optional
+
 from ..datastructures import MutableHeaders
+from ..responses import Response
+from ..types import ASGIApp, Message, Receive, Scope, Send
 
+HTTPDispatchFlow = AsyncGenerator[Optional[Response], Response]
 
-class HTTPMiddleware:
-    DispatchFlow = Generator[Optional[Response], Response, None]
 
-    def __init__(self, app: ASGIApp) -> None:
+class HTTPMiddleware:
+    def __init__(
+        self,
+        app: ASGIApp,
+        dispatch_func: Optional[Callable[[Scope], HTTPDispatchFlow]] = None,
+    ) -> None:
         self.app = app
+        self.dispatch_func = self.dispatch if dispatch_func is None else dispatch_func
 
-    def dispatch(self, scope: Scope) -> DispatchFlow:
-        raise NotImplementedError
+    def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
+        raise NotImplementedError  # pragma: no cover
 
     async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
         if scope["type"] != "http":
             await self.app(scope, receive, send)
             return
 
-        flow = self.dispatch(scope)
-
-        try:
-            # Run until first `yield` to allow modifying the connection scope.
-            # Middleware may return a response before we call into the underlying app.
-            early_response = flow.__next__()
-        except StopIteration:
-            raise RuntimeError("dispatch() did not run")
+        async with aclosing(self.dispatch(scope)) as flow:
+            # Kick the flow until the first `yield`.
+            # Might respond early before we call into the app.
+            early_response = await flow.__anext__()
 
-        if early_response is not None:
-            await early_response(scope, receive, send)
-            return
+            if early_response is not None:
+                await early_response(scope, receive, send)
+                return
 
-        response_started = set[bool]()
+            response_started = set[bool]()
 
-        wrapped_send = partial(
-            self._send,
-            flow=flow,
-            response_started=response_started,
-            send=send,
-        )
+            wrapped_send = partial(
+                self._send,
+                flow=flow,
+                response_started=response_started,
+                send=send,
+            )
 
-        try:
-            await self.app(scope, receive, wrapped_send)
-        except Exception as exc:
             try:
-                response = flow.throw(exc)
-            except Exception:
-                # Exception was not handled, or they raised another one.
-                raise
+                await self.app(scope, receive, wrapped_send)
+            except Exception as exc:
+                if response_started:
+                    raise
+
+                try:
+                    response = await flow.athrow(exc)
+                except Exception:
+                    # Exception was not handled, or they raised another one.
+                    raise
 
-            if response is None:
-                raise RuntimeError(
-                    f"dispatch() handled exception {exc!r}, "
-                    "but no response application was returned"
-                )
+                if response is None:
+                    raise RuntimeError(
+                        f"dispatch() handled exception {exc!r}, "
+                        "but no response was returned"
+                    )
 
-            await response(scope, receive, send)
+                await response(scope, receive, send)
 
-        if not response_started:
-            raise RuntimeError("No response returned.")
+            if not response_started:
+                raise RuntimeError("No response returned.")
 
     async def _send(
-        self, message: Message, *, flow: DispatchFlow, response_started: set, send: Send
+        self,
+        message: Message,
+        *,
+        flow: HTTPDispatchFlow,
+        response_started: set,
+        send: Send,
     ) -> None:
         if message["type"] == "http.response.start":
             response_started.add(True)
@@ -71,10 +82,9 @@ class HTTPMiddleware:
             response.raw_headers.clear()
 
             try:
-                flow.send(response)
-            except StopIteration as exc:
-                if exc.value is not None:
-                    raise RuntimeError("swapping responses it not supported")
+                await flow.asend(response)
+            except StopAsyncIteration:
+                pass
             else:
                 raise RuntimeError("dispatch() should yield exactly once")
 
index af951891628127d3ee4e94b4c27cf356535a966e..c0213c76d42a2f57285d2093fbbaadc22ca4159d 100644 (file)
@@ -4,14 +4,14 @@ import pytest
 
 from starlette.applications import Starlette
 from starlette.middleware import Middleware
-from starlette.middleware.http import HTTPMiddleware
+from starlette.middleware.http import HTTPDispatchFlow, HTTPMiddleware
 from starlette.responses import PlainTextResponse, StreamingResponse
 from starlette.routing import Mount, Route, WebSocketRoute
 from starlette.types import ASGIApp, Receive, Scope, Send
 
 
 class CustomMiddleware(HTTPMiddleware):
-    def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow:
+    async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
         response = yield None
         response.headers["Custom-Header"] = "Example"
 
@@ -88,18 +88,18 @@ def test_state_data_across_multiple_middlewares(test_client_factory):
     expected_value2 = "bar"
 
     class aMiddleware(HTTPMiddleware):
-        def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow:
+        async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
             scope["state_foo"] = expected_value1
             yield None
 
     class bMiddleware(HTTPMiddleware):
-        def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow:
+        async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
             scope["state_bar"] = expected_value2
             response = yield None
             response.headers["X-State-Foo"] = scope["state_foo"]
 
     class cMiddleware(HTTPMiddleware):
-        def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow:
+        async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
             response = yield None
             response.headers["X-State-Bar"] = scope["state_bar"]
 
@@ -143,7 +143,7 @@ def test_middleware_repr():
 def test_fully_evaluated_response(test_client_factory):
     # Test for https://github.com/encode/starlette/issues/1022
     class CustomMiddleware(HTTPMiddleware):
-        def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow:
+        async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
             yield PlainTextResponse("Custom")
 
     app = Starlette(middleware=[Middleware(CustomMiddleware)])
@@ -177,7 +177,7 @@ class CustomMiddlewareWithoutBaseHTTPMiddleware:
 
 
 class CustomMiddlewareUsingHTTPMiddleware(HTTPMiddleware):
-    def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow:
+    async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
         ctxvar.set("set by middleware")
         yield None
         assert ctxvar.get() == "set by endpoint"