]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
More fine-tuning
authorflorimondmanca <florimond.manca@protonmail.com>
Tue, 14 Jun 2022 22:31:58 +0000 (00:31 +0200)
committerflorimondmanca <florimond.manca@protonmail.com>
Tue, 14 Jun 2022 22:50:04 +0000 (00:50 +0200)
starlette/middleware/http.py
tests/middleware/test_http.py

index a2365eca70960875e6250e66573182f2bd451854..16b7b35cb32653d3d000dcb9db26904f1a6ca67d 100644 (file)
@@ -7,8 +7,19 @@ from ..responses import Response
 from ..types import ASGIApp, Message, Receive, Scope, Send
 
 _DispatchFlow = Union[
+    # Default case:
+    # response = yield
     AsyncGenerator[None, Response],
-    AsyncGenerator[Response, Response],
+    # Early response and/or error handling:
+    # if condition:
+    #     yield Response(...)
+    #     return
+    # try:
+    #     response = yield None
+    # except Exception:
+    #     yield Response(...)
+    # else:
+    #    ...
     AsyncGenerator[Optional[Response], Response],
 ]
 
index b65d5a6d51db58d7b8bd80972fe2c02090210205..5f83d313feb4a9469f5399b386376d1b05acd5f2 100644 (file)
@@ -154,28 +154,40 @@ def test_middleware_repr():
 
 
 def test_early_response(test_client_factory):
+    async def index(request):
+        return PlainTextResponse("Hello, world!")
+
     class CustomMiddleware(HTTPMiddleware):
         async def dispatch(
             self, conn: HTTPConnection
-        ) -> AsyncGenerator[Response, Response]:
-            yield Response(status_code=401)
+        ) -> AsyncGenerator[Optional[Response], Response]:
+            if conn.headers.get("X-Early") == "true":
+                yield Response(status_code=401)
+            else:
+                yield None
 
-    app = Starlette(middleware=[Middleware(CustomMiddleware)])
+    app = Starlette(
+        routes=[Route("/", index)],
+        middleware=[Middleware(CustomMiddleware)],
+    )
 
     client = test_client_factory(app)
     response = client.get("/")
+    assert response.status_code == 200
+    assert response.text == "Hello, world!"
+    response = client.get("/", headers={"X-Early": "true"})
     assert response.status_code == 401
 
 
 def test_too_many_yields(test_client_factory) -> None:
-    class BadMiddleware(HTTPMiddleware):
+    class CustomMiddleware(HTTPMiddleware):
         async def dispatch(
             self, conn: HTTPConnection
         ) -> AsyncGenerator[None, Response]:
             _ = yield
             yield
 
-    app = Starlette(middleware=[Middleware(BadMiddleware)])
+    app = Starlette(middleware=[Middleware(CustomMiddleware)])
 
     client = test_client_factory(app)
     with pytest.raises(RuntimeError, match="should yield exactly once"):
@@ -189,7 +201,7 @@ def test_error_response(test_client_factory) -> None:
     async def failure(request):
         raise Failed()
 
-    class ErrorMiddleware(HTTPMiddleware):
+    class CustomMiddleware(HTTPMiddleware):
         async def dispatch(
             self, conn: HTTPConnection
         ) -> AsyncGenerator[Optional[Response], Response]:
@@ -200,7 +212,7 @@ def test_error_response(test_client_factory) -> None:
 
     app = Starlette(
         routes=[Route("/fail", failure)],
-        middleware=[Middleware(ErrorMiddleware)],
+        middleware=[Middleware(CustomMiddleware)],
     )
 
     client = test_client_factory(app)
@@ -216,7 +228,7 @@ def test_no_error_response(test_client_factory) -> None:
     async def index(request):
         raise Failed()
 
-    class BadMiddleware(HTTPMiddleware):
+    class CustomMiddleware(HTTPMiddleware):
         async def dispatch(
             self, conn: HTTPConnection
         ) -> AsyncGenerator[None, Response]:
@@ -225,7 +237,10 @@ def test_no_error_response(test_client_factory) -> None:
             except Failed:
                 pass
 
-    app = Starlette(routes=[Route("/", index)], middleware=[Middleware(BadMiddleware)])
+    app = Starlette(
+        routes=[Route("/", index)],
+        middleware=[Middleware(CustomMiddleware)],
+    )
 
     client = test_client_factory(app)
     with pytest.raises(RuntimeError, match="no response was returned"):
@@ -260,9 +275,6 @@ class HTTPCustomMiddleware(HTTPMiddleware):
     ],
 )
 def test_contextvars(test_client_factory, middleware_cls: type):
-    # this has to be an async endpoint because Starlette calls run_in_threadpool
-    # on sync endpoints which has it's own set of peculiarities w.r.t propagating
-    # contextvars (it propagates them forwards but not backwards)
     async def homepage(request):
         assert ctxvar.get() == "set by middleware"
         ctxvar.set("set by endpoint")