From: florimondmanca Date: Tue, 14 Jun 2022 21:04:45 +0000 (+0200) Subject: Refactor, lint X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=9084939fea310461723d4f2a5db0936193b2c6ac;p=thirdparty%2Fstarlette.git Refactor, lint --- diff --git a/starlette/middleware/http.py b/starlette/middleware/http.py index 61571d54..8284e989 100644 --- a/starlette/middleware/http.py +++ b/starlette/middleware/http.py @@ -1,68 +1,79 @@ +from contextlib import aclosing from functools import partial -from typing import Generator, Optional -from ..types import ASGIApp, Scope, Receive, Send, Message -from ..responses import Response +from typing import AsyncGenerator, Callable, Optional + from ..datastructures import MutableHeaders +from ..responses import Response +from ..types import ASGIApp, Message, Receive, Scope, Send +HTTPDispatchFlow = AsyncGenerator[Optional[Response], Response] -class HTTPMiddleware: - DispatchFlow = Generator[Optional[Response], Response, None] - def __init__(self, app: ASGIApp) -> None: +class HTTPMiddleware: + def __init__( + self, + app: ASGIApp, + dispatch_func: Optional[Callable[[Scope], HTTPDispatchFlow]] = None, + ) -> None: self.app = app + self.dispatch_func = self.dispatch if dispatch_func is None else dispatch_func - def dispatch(self, scope: Scope) -> DispatchFlow: - raise NotImplementedError + def dispatch(self, scope: Scope) -> HTTPDispatchFlow: + raise NotImplementedError # pragma: no cover async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] != "http": await self.app(scope, receive, send) return - flow = self.dispatch(scope) - - try: - # Run until first `yield` to allow modifying the connection scope. - # Middleware may return a response before we call into the underlying app. - early_response = flow.__next__() - except StopIteration: - raise RuntimeError("dispatch() did not run") + async with aclosing(self.dispatch(scope)) as flow: + # Kick the flow until the first `yield`. + # Might respond early before we call into the app. + early_response = await flow.__anext__() - if early_response is not None: - await early_response(scope, receive, send) - return + if early_response is not None: + await early_response(scope, receive, send) + return - response_started = set[bool]() + response_started = set[bool]() - wrapped_send = partial( - self._send, - flow=flow, - response_started=response_started, - send=send, - ) + wrapped_send = partial( + self._send, + flow=flow, + response_started=response_started, + send=send, + ) - try: - await self.app(scope, receive, wrapped_send) - except Exception as exc: try: - response = flow.throw(exc) - except Exception: - # Exception was not handled, or they raised another one. - raise + await self.app(scope, receive, wrapped_send) + except Exception as exc: + if response_started: + raise + + try: + response = await flow.athrow(exc) + except Exception: + # Exception was not handled, or they raised another one. + raise - if response is None: - raise RuntimeError( - f"dispatch() handled exception {exc!r}, " - "but no response application was returned" - ) + if response is None: + raise RuntimeError( + f"dispatch() handled exception {exc!r}, " + "but no response was returned" + ) - await response(scope, receive, send) + await response(scope, receive, send) - if not response_started: - raise RuntimeError("No response returned.") + if not response_started: + raise RuntimeError("No response returned.") async def _send( - self, message: Message, *, flow: DispatchFlow, response_started: set, send: Send + self, + message: Message, + *, + flow: HTTPDispatchFlow, + response_started: set, + send: Send, ) -> None: if message["type"] == "http.response.start": response_started.add(True) @@ -71,10 +82,9 @@ class HTTPMiddleware: response.raw_headers.clear() try: - flow.send(response) - except StopIteration as exc: - if exc.value is not None: - raise RuntimeError("swapping responses it not supported") + await flow.asend(response) + except StopAsyncIteration: + pass else: raise RuntimeError("dispatch() should yield exactly once") diff --git a/tests/middleware/test_http.py b/tests/middleware/test_http.py index af951891..c0213c76 100644 --- a/tests/middleware/test_http.py +++ b/tests/middleware/test_http.py @@ -4,14 +4,14 @@ import pytest from starlette.applications import Starlette from starlette.middleware import Middleware -from starlette.middleware.http import HTTPMiddleware +from starlette.middleware.http import HTTPDispatchFlow, HTTPMiddleware from starlette.responses import PlainTextResponse, StreamingResponse from starlette.routing import Mount, Route, WebSocketRoute from starlette.types import ASGIApp, Receive, Scope, Send class CustomMiddleware(HTTPMiddleware): - def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow: + async def dispatch(self, scope: Scope) -> HTTPDispatchFlow: response = yield None response.headers["Custom-Header"] = "Example" @@ -88,18 +88,18 @@ def test_state_data_across_multiple_middlewares(test_client_factory): expected_value2 = "bar" class aMiddleware(HTTPMiddleware): - def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow: + async def dispatch(self, scope: Scope) -> HTTPDispatchFlow: scope["state_foo"] = expected_value1 yield None class bMiddleware(HTTPMiddleware): - def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow: + async def dispatch(self, scope: Scope) -> HTTPDispatchFlow: scope["state_bar"] = expected_value2 response = yield None response.headers["X-State-Foo"] = scope["state_foo"] class cMiddleware(HTTPMiddleware): - def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow: + async def dispatch(self, scope: Scope) -> HTTPDispatchFlow: response = yield None response.headers["X-State-Bar"] = scope["state_bar"] @@ -143,7 +143,7 @@ def test_middleware_repr(): def test_fully_evaluated_response(test_client_factory): # Test for https://github.com/encode/starlette/issues/1022 class CustomMiddleware(HTTPMiddleware): - def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow: + async def dispatch(self, scope: Scope) -> HTTPDispatchFlow: yield PlainTextResponse("Custom") app = Starlette(middleware=[Middleware(CustomMiddleware)]) @@ -177,7 +177,7 @@ class CustomMiddlewareWithoutBaseHTTPMiddleware: class CustomMiddlewareUsingHTTPMiddleware(HTTPMiddleware): - def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow: + async def dispatch(self, scope: Scope) -> HTTPDispatchFlow: ctxvar.set("set by middleware") yield None assert ctxvar.get() == "set by endpoint"