]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Address feedback
authorflorimondmanca <florimond.manca@protonmail.com>
Tue, 14 Jun 2022 21:29:34 +0000 (23:29 +0200)
committerflorimondmanca <florimond.manca@protonmail.com>
Tue, 14 Jun 2022 21:39:24 +0000 (23:39 +0200)
setup.py
starlette/_compat.py
starlette/middleware/http.py
tests/middleware/test_http.py

index f0c160d9b8dd6948d46e3b6a70c0c19945dfde2a..2ee644d42ba7926ffdc36359fdf6c29c2fe984a7 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -39,7 +39,7 @@ setup(
     install_requires=[
         "anyio>=3.4.0,<5",
         "typing_extensions>=3.10.0; python_version < '3.10'",
-        "async_generator; python < '3.10'",
+        "async_generator; python_version < '3.10'",
     ],
     extras_require={
         "full": [
index bba8821f8ae98048af669d9dc48a6d08cb6bc159..b144604031fd16a3f8d9290536a03dfb4abb5f44 100644 (file)
@@ -28,7 +28,6 @@ try:
             data, usedforsecurity=usedforsecurity
         ).hexdigest()
 
-
 except TypeError:  # pragma: no cover
 
     def md5_hexdigest(data: bytes, *, usedforsecurity: bool = True) -> str:
@@ -36,6 +35,6 @@ except TypeError:  # pragma: no cover
 
 
 try:
-    from contextlib import aclosing
+    from contextlib import aclosing  # type: ignore[attr-defined]
 except ImportError:  # Python < 3.10
     from async_generator import aclosing  # type: ignore
index f3c07064ae563eaab667fc3d651b7ccff12a4a18..01b9fed80e5220045bd7736811313ba3833caddc 100644 (file)
@@ -1,24 +1,25 @@
-from functools import partial
-from typing import AsyncGenerator, Callable, Optional
+from typing import AsyncGenerator, Callable, Optional, Union
 
 from .._compat import aclosing
 from ..datastructures import MutableHeaders
 from ..responses import Response
 from ..types import ASGIApp, Message, Receive, Scope, Send
 
-HTTPDispatchFlow = AsyncGenerator[Optional[Response], Response]
+_HTTPDispatchFlow = Union[
+    AsyncGenerator[None, Response], AsyncGenerator[ASGIApp, Response]
+]
 
 
 class HTTPMiddleware:
     def __init__(
         self,
         app: ASGIApp,
-        dispatch_func: Optional[Callable[[Scope], HTTPDispatchFlow]] = None,
+        dispatch_func: Optional[Callable[[Scope], _HTTPDispatchFlow]] = None,
     ) -> None:
         self.app = app
         self.dispatch_func = self.dispatch if dispatch_func is None else dispatch_func
 
-    def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
+    def dispatch(self, scope: Scope) -> _HTTPDispatchFlow:
         raise NotImplementedError  # pragma: no cover
 
     async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
@@ -29,20 +30,32 @@ class HTTPMiddleware:
         async with aclosing(self.dispatch(scope)) as flow:
             # Kick the flow until the first `yield`.
             # Might respond early before we call into the app.
-            early_response = await flow.__anext__()
+            maybe_early_response = await flow.__anext__()
 
-            if early_response is not None:
-                await early_response(scope, receive, send)
+            if maybe_early_response is not None:
+                await maybe_early_response(scope, receive, send)
                 return
 
             response_started = set[bool]()
 
-            wrapped_send = partial(
-                self._send,
-                flow=flow,
-                response_started=response_started,
-                send=send,
-            )
+            async def wrapped_send(message: Message) -> None:
+                if message["type"] == "http.response.start":
+                    response_started.add(True)
+
+                    response = Response(status_code=message["status"])
+                    response.raw_headers.clear()
+
+                    try:
+                        await flow.asend(response)
+                    except StopAsyncIteration:
+                        pass
+                    else:
+                        raise RuntimeError("dispatch() should yield exactly once")
+
+                    headers = MutableHeaders(raw=message["headers"])
+                    headers.update(response.headers)
+
+                await send(message)
 
             try:
                 await self.app(scope, receive, wrapped_send)
@@ -66,29 +79,3 @@ class HTTPMiddleware:
 
             if not response_started:
                 raise RuntimeError("No response returned.")
-
-    async def _send(
-        self,
-        message: Message,
-        *,
-        flow: HTTPDispatchFlow,
-        response_started: set,
-        send: Send,
-    ) -> None:
-        if message["type"] == "http.response.start":
-            response_started.add(True)
-
-            response = Response(status_code=message["status"])
-            response.raw_headers.clear()
-
-            try:
-                await flow.asend(response)
-            except StopAsyncIteration:
-                pass
-            else:
-                raise RuntimeError("dispatch() should yield exactly once")
-
-            headers = MutableHeaders(raw=message["headers"])
-            headers.update(response.headers)
-
-        await send(message)
index c0213c76d42a2f57285d2093fbbaadc22ca4159d..abff506b6c7d396f5ee897ce7636ce6aab418461 100644 (file)
@@ -1,18 +1,19 @@
 import contextvars
+from typing import AsyncGenerator
 
 import pytest
 
 from starlette.applications import Starlette
 from starlette.middleware import Middleware
-from starlette.middleware.http import HTTPDispatchFlow, HTTPMiddleware
-from starlette.responses import PlainTextResponse, StreamingResponse
-from starlette.routing import Mount, Route, WebSocketRoute
+from starlette.middleware.http import HTTPMiddleware
+from starlette.responses import PlainTextResponse, Response, StreamingResponse
+from starlette.routing import Route, WebSocketRoute
 from starlette.types import ASGIApp, Receive, Scope, Send
 
 
 class CustomMiddleware(HTTPMiddleware):
-    async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
-        response = yield None
+    async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
+        response = yield
         response.headers["Custom-Header"] = "Example"
 
 
@@ -88,19 +89,19 @@ def test_state_data_across_multiple_middlewares(test_client_factory):
     expected_value2 = "bar"
 
     class aMiddleware(HTTPMiddleware):
-        async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
+        async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
             scope["state_foo"] = expected_value1
-            yield None
+            yield
 
     class bMiddleware(HTTPMiddleware):
-        async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
+        async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
             scope["state_bar"] = expected_value2
-            response = yield None
+            response = yield
             response.headers["X-State-Foo"] = scope["state_foo"]
 
     class cMiddleware(HTTPMiddleware):
-        async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
-            response = yield None
+        async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
+            response = yield
             response.headers["X-State-Bar"] = scope["state_bar"]
 
     def homepage(request):
@@ -143,7 +144,7 @@ def test_middleware_repr():
 def test_fully_evaluated_response(test_client_factory):
     # Test for https://github.com/encode/starlette/issues/1022
     class CustomMiddleware(HTTPMiddleware):
-        async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
+        async def dispatch(self, scope: Scope) -> AsyncGenerator[Response, Response]:
             yield PlainTextResponse("Custom")
 
     app = Starlette(middleware=[Middleware(CustomMiddleware)])
@@ -153,16 +154,6 @@ def test_fully_evaluated_response(test_client_factory):
     assert response.text == "Custom"
 
 
-def test_exception_on_mounted_apps(test_client_factory):
-    sub_app = Starlette(routes=[Route("/", exc)])
-    app = Starlette(routes=[Mount("/sub", app=sub_app)])
-
-    client = test_client_factory(app)
-    with pytest.raises(Exception) as ctx:
-        client.get("/sub/")
-    assert str(ctx.value) == "Exc"
-
-
 ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar")
 
 
@@ -177,9 +168,9 @@ class CustomMiddlewareWithoutBaseHTTPMiddleware:
 
 
 class CustomMiddlewareUsingHTTPMiddleware(HTTPMiddleware):
-    async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
+    async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
         ctxvar.set("set by middleware")
-        yield None
+        yield
         assert ctxvar.get() == "set by endpoint"