From: florimondmanca Date: Tue, 14 Jun 2022 21:59:04 +0000 (+0200) Subject: Address, refactor X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=cf163039372466ad93c566cabd522b8dff844ee8;p=thirdparty%2Fstarlette.git Address, refactor --- diff --git a/setup.py b/setup.py index 2ee644d4..1597ef45 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,6 @@ setup( install_requires=[ "anyio>=3.4.0,<5", "typing_extensions>=3.10.0; python_version < '3.10'", - "async_generator; python_version < '3.10'", ], extras_require={ "full": [ diff --git a/starlette/_compat.py b/starlette/_compat.py index b1446040..569c06ff 100644 --- a/starlette/_compat.py +++ b/starlette/_compat.py @@ -1,4 +1,5 @@ import hashlib +from typing import Any __all__ = [ "md5_hexdigest", @@ -37,4 +38,13 @@ except TypeError: # pragma: no cover try: from contextlib import aclosing # type: ignore[attr-defined] except ImportError: # Python < 3.10 - from async_generator import aclosing # type: ignore + + class aclosing: # type: ignore + def __init__(self, thing: Any) -> None: + self.thing = thing + + async def __aenter__(self) -> Any: + return self.thing + + async def __aexit__(self, *exc_info: Any) -> None: + await self.thing.aclose() diff --git a/starlette/middleware/http.py b/starlette/middleware/http.py index 01b9fed8..ce4a7373 100644 --- a/starlette/middleware/http.py +++ b/starlette/middleware/http.py @@ -6,7 +6,9 @@ from ..responses import Response from ..types import ASGIApp, Message, Receive, Scope, Send _HTTPDispatchFlow = Union[ - AsyncGenerator[None, Response], AsyncGenerator[ASGIApp, Response] + AsyncGenerator[None, Response], + AsyncGenerator[Response, Response], + AsyncGenerator[Optional[Response], Response], ] @@ -14,10 +16,10 @@ class HTTPMiddleware: def __init__( self, app: ASGIApp, - dispatch_func: Optional[Callable[[Scope], _HTTPDispatchFlow]] = None, + dispatch: Optional[Callable[[Scope], _HTTPDispatchFlow]] = None, ) -> None: self.app = app - self.dispatch_func = self.dispatch if dispatch_func is None else dispatch_func + self.dispatch_func = self.dispatch if dispatch is None else dispatch def dispatch(self, scope: Scope) -> _HTTPDispatchFlow: raise NotImplementedError # pragma: no cover @@ -36,7 +38,7 @@ class HTTPMiddleware: await maybe_early_response(scope, receive, send) return - response_started = set[bool]() + response_started: set = set() async def wrapped_send(message: Message) -> None: if message["type"] == "http.response.start": @@ -65,6 +67,8 @@ class HTTPMiddleware: try: response = await flow.athrow(exc) + except StopAsyncIteration: + response = None except Exception: # Exception was not handled, or they raised another one. raise @@ -76,6 +80,7 @@ class HTTPMiddleware: ) await response(scope, receive, send) + return if not response_started: raise RuntimeError("No response returned.") diff --git a/tests/middleware/test_http.py b/tests/middleware/test_http.py index abff506b..a36235c5 100644 --- a/tests/middleware/test_http.py +++ b/tests/middleware/test_http.py @@ -1,5 +1,5 @@ import contextvars -from typing import AsyncGenerator +from typing import AsyncGenerator, Optional import pytest @@ -141,23 +141,83 @@ def test_middleware_repr(): assert repr(middleware) == "Middleware(CustomMiddleware)" -def test_fully_evaluated_response(test_client_factory): - # Test for https://github.com/encode/starlette/issues/1022 +def test_early_response(test_client_factory): class CustomMiddleware(HTTPMiddleware): async def dispatch(self, scope: Scope) -> AsyncGenerator[Response, Response]: - yield PlainTextResponse("Custom") + yield Response(status_code=401) app = Starlette(middleware=[Middleware(CustomMiddleware)]) client = test_client_factory(app) - response = client.get("/does_not_exist") - assert response.text == "Custom" + response = client.get("/") + assert response.status_code == 401 + + +def test_too_many_yields(test_client_factory) -> None: + class BadMiddleware(HTTPMiddleware): + async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]: + _ = yield + yield + + app = Starlette(middleware=[Middleware(BadMiddleware)]) + + client = test_client_factory(app) + with pytest.raises(RuntimeError, match="should yield exactly once"): + client.get("/") + + +def test_error_response(test_client_factory) -> None: + class Failed(Exception): + pass + + async def failure(request): + raise Failed() + + class ErrorMiddleware(HTTPMiddleware): + async def dispatch( + self, scope: Scope + ) -> AsyncGenerator[Optional[Response], Response]: + try: + yield None + except Failed: + yield Response("Failed", status_code=500) + + app = Starlette( + routes=[Route("/fail", failure)], + middleware=[Middleware(ErrorMiddleware)], + ) + + client = test_client_factory(app) + response = client.get("/fail") + assert response.text == "Failed" + assert response.status_code == 500 + + +def test_no_error_response(test_client_factory) -> None: + class Failed(Exception): + pass + + async def index(request): + raise Failed() + + class BadMiddleware(HTTPMiddleware): + async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]: + try: + yield + except Failed: + pass + + app = Starlette(routes=[Route("/", index)], middleware=[Middleware(BadMiddleware)]) + + client = test_client_factory(app) + with pytest.raises(RuntimeError, match="no response was returned"): + client.get("/") ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar") -class CustomMiddlewareWithoutBaseHTTPMiddleware: +class PureASGICustomMiddleware: def __init__(self, app: ASGIApp) -> None: self.app = app @@ -167,7 +227,7 @@ class CustomMiddlewareWithoutBaseHTTPMiddleware: assert ctxvar.get() == "set by endpoint" -class CustomMiddlewareUsingHTTPMiddleware(HTTPMiddleware): +class HTTPCustomMiddleware(HTTPMiddleware): async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]: ctxvar.set("set by middleware") yield @@ -177,8 +237,8 @@ class CustomMiddlewareUsingHTTPMiddleware(HTTPMiddleware): @pytest.mark.parametrize( "middleware_cls", [ - CustomMiddlewareWithoutBaseHTTPMiddleware, - CustomMiddlewareUsingHTTPMiddleware, + PureASGICustomMiddleware, + HTTPCustomMiddleware, ], ) def test_contextvars(test_client_factory, middleware_cls: type):