]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Address, refactor
authorflorimondmanca <florimond.manca@protonmail.com>
Tue, 14 Jun 2022 21:59:04 +0000 (23:59 +0200)
committerflorimondmanca <florimond.manca@protonmail.com>
Tue, 14 Jun 2022 22:03:00 +0000 (00:03 +0200)
setup.py
starlette/_compat.py
starlette/middleware/http.py
tests/middleware/test_http.py

index 2ee644d42ba7926ffdc36359fdf6c29c2fe984a7..1597ef452a4b18e5ff35779a9cdc329d65812ab0 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -39,7 +39,6 @@ setup(
     install_requires=[
         "anyio>=3.4.0,<5",
         "typing_extensions>=3.10.0; python_version < '3.10'",
-        "async_generator; python_version < '3.10'",
     ],
     extras_require={
         "full": [
index b144604031fd16a3f8d9290536a03dfb4abb5f44..569c06ff207ba70c16524178396541b1c0dbb368 100644 (file)
@@ -1,4 +1,5 @@
 import hashlib
+from typing import Any
 
 __all__ = [
     "md5_hexdigest",
@@ -37,4 +38,13 @@ except TypeError:  # pragma: no cover
 try:
     from contextlib import aclosing  # type: ignore[attr-defined]
 except ImportError:  # Python < 3.10
-    from async_generator import aclosing  # type: ignore
+
+    class aclosing:  # type: ignore
+        def __init__(self, thing: Any) -> None:
+            self.thing = thing
+
+        async def __aenter__(self) -> Any:
+            return self.thing
+
+        async def __aexit__(self, *exc_info: Any) -> None:
+            await self.thing.aclose()
index 01b9fed80e5220045bd7736811313ba3833caddc..ce4a7373bcce629ac56c66932e0c1a7c90aa9654 100644 (file)
@@ -6,7 +6,9 @@ from ..responses import Response
 from ..types import ASGIApp, Message, Receive, Scope, Send
 
 _HTTPDispatchFlow = Union[
-    AsyncGenerator[None, Response], AsyncGenerator[ASGIApp, Response]
+    AsyncGenerator[None, Response],
+    AsyncGenerator[Response, Response],
+    AsyncGenerator[Optional[Response], Response],
 ]
 
 
@@ -14,10 +16,10 @@ class HTTPMiddleware:
     def __init__(
         self,
         app: ASGIApp,
-        dispatch_func: Optional[Callable[[Scope], _HTTPDispatchFlow]] = None,
+        dispatch: Optional[Callable[[Scope], _HTTPDispatchFlow]] = None,
     ) -> None:
         self.app = app
-        self.dispatch_func = self.dispatch if dispatch_func is None else dispatch_func
+        self.dispatch_func = self.dispatch if dispatch is None else dispatch
 
     def dispatch(self, scope: Scope) -> _HTTPDispatchFlow:
         raise NotImplementedError  # pragma: no cover
@@ -36,7 +38,7 @@ class HTTPMiddleware:
                 await maybe_early_response(scope, receive, send)
                 return
 
-            response_started = set[bool]()
+            response_started: set = set()
 
             async def wrapped_send(message: Message) -> None:
                 if message["type"] == "http.response.start":
@@ -65,6 +67,8 @@ class HTTPMiddleware:
 
                 try:
                     response = await flow.athrow(exc)
+                except StopAsyncIteration:
+                    response = None
                 except Exception:
                     # Exception was not handled, or they raised another one.
                     raise
@@ -76,6 +80,7 @@ class HTTPMiddleware:
                     )
 
                 await response(scope, receive, send)
+                return
 
             if not response_started:
                 raise RuntimeError("No response returned.")
index abff506b6c7d396f5ee897ce7636ce6aab418461..a36235c5c1b504c350d66a5140087dbd2ca39943 100644 (file)
@@ -1,5 +1,5 @@
 import contextvars
-from typing import AsyncGenerator
+from typing import AsyncGenerator, Optional
 
 import pytest
 
@@ -141,23 +141,83 @@ def test_middleware_repr():
     assert repr(middleware) == "Middleware(CustomMiddleware)"
 
 
-def test_fully_evaluated_response(test_client_factory):
-    # Test for https://github.com/encode/starlette/issues/1022
+def test_early_response(test_client_factory):
     class CustomMiddleware(HTTPMiddleware):
         async def dispatch(self, scope: Scope) -> AsyncGenerator[Response, Response]:
-            yield PlainTextResponse("Custom")
+            yield Response(status_code=401)
 
     app = Starlette(middleware=[Middleware(CustomMiddleware)])
 
     client = test_client_factory(app)
-    response = client.get("/does_not_exist")
-    assert response.text == "Custom"
+    response = client.get("/")
+    assert response.status_code == 401
+
+
+def test_too_many_yields(test_client_factory) -> None:
+    class BadMiddleware(HTTPMiddleware):
+        async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
+            _ = yield
+            yield
+
+    app = Starlette(middleware=[Middleware(BadMiddleware)])
+
+    client = test_client_factory(app)
+    with pytest.raises(RuntimeError, match="should yield exactly once"):
+        client.get("/")
+
+
+def test_error_response(test_client_factory) -> None:
+    class Failed(Exception):
+        pass
+
+    async def failure(request):
+        raise Failed()
+
+    class ErrorMiddleware(HTTPMiddleware):
+        async def dispatch(
+            self, scope: Scope
+        ) -> AsyncGenerator[Optional[Response], Response]:
+            try:
+                yield None
+            except Failed:
+                yield Response("Failed", status_code=500)
+
+    app = Starlette(
+        routes=[Route("/fail", failure)],
+        middleware=[Middleware(ErrorMiddleware)],
+    )
+
+    client = test_client_factory(app)
+    response = client.get("/fail")
+    assert response.text == "Failed"
+    assert response.status_code == 500
+
+
+def test_no_error_response(test_client_factory) -> None:
+    class Failed(Exception):
+        pass
+
+    async def index(request):
+        raise Failed()
+
+    class BadMiddleware(HTTPMiddleware):
+        async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
+            try:
+                yield
+            except Failed:
+                pass
+
+    app = Starlette(routes=[Route("/", index)], middleware=[Middleware(BadMiddleware)])
+
+    client = test_client_factory(app)
+    with pytest.raises(RuntimeError, match="no response was returned"):
+        client.get("/")
 
 
 ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar")
 
 
-class CustomMiddlewareWithoutBaseHTTPMiddleware:
+class PureASGICustomMiddleware:
     def __init__(self, app: ASGIApp) -> None:
         self.app = app
 
@@ -167,7 +227,7 @@ class CustomMiddlewareWithoutBaseHTTPMiddleware:
         assert ctxvar.get() == "set by endpoint"
 
 
-class CustomMiddlewareUsingHTTPMiddleware(HTTPMiddleware):
+class HTTPCustomMiddleware(HTTPMiddleware):
     async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
         ctxvar.set("set by middleware")
         yield
@@ -177,8 +237,8 @@ class CustomMiddlewareUsingHTTPMiddleware(HTTPMiddleware):
 @pytest.mark.parametrize(
     "middleware_cls",
     [
-        CustomMiddlewareWithoutBaseHTTPMiddleware,
-        CustomMiddlewareUsingHTTPMiddleware,
+        PureASGICustomMiddleware,
+        HTTPCustomMiddleware,
     ],
 )
 def test_contextvars(test_client_factory, middleware_cls: type):