]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Attempt controlling response attrs
authorflorimondmanca <florimond.manca@protonmail.com>
Wed, 15 Jun 2022 21:51:52 +0000 (23:51 +0200)
committerflorimondmanca <florimond.manca@protonmail.com>
Wed, 15 Jun 2022 21:55:14 +0000 (23:55 +0200)
starlette/middleware/http.py
tests/middleware/test_http.py

index aa16752639b4d1eb1207b2b07c684623568c6540..4341458ccce25c82f284334935f762f18b6616d3 100644 (file)
@@ -1,7 +1,7 @@
-from typing import AsyncGenerator, Callable, Optional, Union
+from typing import Any, AsyncGenerator, Callable, Optional, Union
 
 from .._compat import aclosing
-from ..datastructures import MutableHeaders
+from ..datastructures import Headers
 from ..requests import HTTPConnection
 from ..responses import Response
 from ..types import ASGIApp, Message, Receive, Scope, Send
@@ -73,8 +73,12 @@ class HTTPMiddleware:
                 if message["type"] == "http.response.start":
                     response_started = True
 
-                    response = Response(status_code=message["status"])
-                    response.raw_headers.clear()
+                    headers = Headers(raw=message["headers"])
+                    response = _StubResponse(
+                        status_code=message["status"],
+                        media_type=headers.get("content-type"),
+                    )
+                    response.raw_headers = headers.raw
 
                     try:
                         await flow.asend(response)
@@ -83,9 +87,7 @@ class HTTPMiddleware:
                     else:
                         raise RuntimeError("dispatch() should yield exactly once")
 
-                    headers = MutableHeaders(raw=message["headers"])
-                    headers.update(response.headers)
-                    message["headers"] = headers.raw
+                    message["headers"] = response.raw_headers
 
                 await send(message)
 
@@ -111,3 +113,58 @@ class HTTPMiddleware:
 
                 await response(scope, receive, send)
                 return
+
+
+# This customized stub response helps prevent users from shooting themselves
+# in the foot, doing things that don't actually have any effect.
+
+
+class _StubResponse(Response):
+    def __init__(self, status_code: int, media_type: Optional[str] = None) -> None:
+        self._status_code = status_code
+        self._media_type = media_type
+        self.raw_headers = []
+
+    @property  # type: ignore
+    def status_code(self) -> int:  # type: ignore
+        return self._status_code
+
+    @status_code.setter
+    def status_code(self, value: Any) -> None:
+        raise RuntimeError(
+            "Setting .status_code in HTTPMiddleware is not supported. "
+            "If you're writing middleware that requires modifying the response "
+            "status code or sending another response altogether, please consider "
+            "writing pure ASGI middleware. "
+            "See: https://starlette.io/middleware/#pure-asgi-middleware"
+        )
+
+    @property  # type: ignore
+    def media_type(self) -> Optional[str]:  # type: ignore
+        return self._media_type
+
+    @media_type.setter
+    def media_type(self, value: Any) -> None:
+        raise RuntimeError(
+            "Setting .media_type in HTTPMiddleware is not supported, as it has "
+            "no effect. If you do need to tweak the response "
+            "content type, consider: response.headers['Content-Type'] = ..."
+        )
+
+    @property  # type: ignore
+    def body(self) -> bytes:  # type: ignore
+        raise RuntimeError(
+            "Accessing the response body in HTTPMiddleware is not supported. "
+            "If you're writing middleware that requires peeking into the response "
+            "body, please consider writing pure ASGI middleware and wrapping send(). "
+            "See: https://starlette.io/middleware/#pure-asgi-middleware"
+        )
+
+    @body.setter
+    def body(self, body: bytes) -> None:
+        raise RuntimeError(
+            "Setting the response body in HTTPMiddleware is not supported."
+            "If you're writing middleware that requires modifying the response "
+            "body, please consider writing pure ASGI middleware and wrapping send(). "
+            "See: https://starlette.io/middleware/#pure-asgi-middleware"
+        )
index cbe1c3ac73b0fef1a8d1103b7fd9fbfa6027367c..08162503081ecd0bb86096220cb2ae3ff74b413f 100644 (file)
@@ -127,8 +127,8 @@ def test_too_many_yields(test_client_factory: Callable[[ASGIApp], TestClient]) -
 
 
 def test_early_response(test_client_factory: Callable[[ASGIApp], TestClient]) -> None:
-    async def index(request: Request) -> Response:
-        return PlainTextResponse("Hello, world!")
+    async def homepage(request: Request) -> Response:
+        return PlainTextResponse("OK")
 
     class CustomMiddleware(HTTPMiddleware):
         async def dispatch(
@@ -140,14 +140,14 @@ def test_early_response(test_client_factory: Callable[[ASGIApp], TestClient]) ->
                 yield None
 
     app = Starlette(
-        routes=[Route("/", index)],
+        routes=[Route("/", homepage)],
         middleware=[Middleware(CustomMiddleware)],
     )
 
     client = test_client_factory(app)
     response = client.get("/")
     assert response.status_code == 200
-    assert response.text == "Hello, world!"
+    assert response.text == "OK"
     response = client.get("/", headers={"X-Early": "true"})
     assert response.status_code == 401
 
@@ -202,7 +202,7 @@ def test_error_handling_must_send_response(
     class Failed(Exception):
         pass
 
-    async def index(request: Request) -> Response:
+    async def failure(request: Request) -> Response:
         raise Failed()
 
     class CustomMiddleware(HTTPMiddleware):
@@ -215,19 +215,74 @@ def test_error_handling_must_send_response(
                 pass  # `yield <response>` expected
 
     app = Starlette(
-        routes=[Route("/", index)],
+        routes=[Route("/fail", failure)],
         middleware=[Middleware(CustomMiddleware)],
     )
 
     client = test_client_factory(app)
     with pytest.raises(RuntimeError, match="no response was returned"):
-        client.get("/")
+        client.get("/fail")
 
 
 def test_no_dispatch_given(
     test_client_factory: Callable[[ASGIApp], TestClient]
 ) -> None:
     app = Starlette(middleware=[Middleware(HTTPMiddleware)])
+
     client = test_client_factory(app)
     with pytest.raises(NotImplementedError, match="No dispatch implementation"):
         client.get("/")
+
+
+def test_response_stub_attributes(
+    test_client_factory: Callable[[ASGIApp], TestClient]
+) -> None:
+    async def homepage(request: Request) -> Response:
+        return PlainTextResponse("OK")
+
+    async def dispatch(conn: HTTPConnection) -> AsyncGenerator[None, Response]:
+        response = yield
+        if conn.url.path == "/status_code":
+            assert response.status_code == 200
+            response.status_code = 401
+        if conn.url.path == "/media_type":
+            assert response.media_type == "text/plain; charset=utf-8"
+            response.media_type = "text/csv"
+        if conn.url.path == "/body-get":
+            response.body
+        if conn.url.path == "/body-set":
+            response.body = b"changed"
+
+    app = Starlette(
+        routes=[
+            Route("/status_code", homepage),
+            Route("/media_type", homepage),
+            Route("/body-get", homepage),
+            Route("/body-set", homepage),
+        ],
+        middleware=[Middleware(HTTPMiddleware, dispatch=dispatch)],
+    )
+
+    client = test_client_factory(app)
+
+    with pytest.raises(
+        RuntimeError, match="Setting .status_code in HTTPMiddleware is not supported."
+    ):
+        client.get("/status_code")
+
+    with pytest.raises(
+        RuntimeError, match="Setting .media_type in HTTPMiddleware is not supported"
+    ):
+        client.get("/media_type")
+
+    with pytest.raises(
+        RuntimeError,
+        match="Accessing the response body in HTTPMiddleware is not supported",
+    ):
+        client.get("/body-get")
+
+    with pytest.raises(
+        RuntimeError,
+        match="Setting the response body in HTTPMiddleware is not supported",
+    ):
+        client.get("/body-set")