-from typing import AsyncGenerator, Callable, Optional, Union
+from typing import Any, AsyncGenerator, Callable, Optional, Union
from .._compat import aclosing
-from ..datastructures import MutableHeaders
+from ..datastructures import Headers
from ..requests import HTTPConnection
from ..responses import Response
from ..types import ASGIApp, Message, Receive, Scope, Send
if message["type"] == "http.response.start":
response_started = True
- response = Response(status_code=message["status"])
- response.raw_headers.clear()
+ headers = Headers(raw=message["headers"])
+ response = _StubResponse(
+ status_code=message["status"],
+ media_type=headers.get("content-type"),
+ )
+ response.raw_headers = headers.raw
try:
await flow.asend(response)
else:
raise RuntimeError("dispatch() should yield exactly once")
- headers = MutableHeaders(raw=message["headers"])
- headers.update(response.headers)
- message["headers"] = headers.raw
+ message["headers"] = response.raw_headers
await send(message)
await response(scope, receive, send)
return
+
+
+# This customized stub response helps prevent users from shooting themselves
+# in the foot, doing things that don't actually have any effect.
+
+
+class _StubResponse(Response):
+ def __init__(self, status_code: int, media_type: Optional[str] = None) -> None:
+ self._status_code = status_code
+ self._media_type = media_type
+ self.raw_headers = []
+
+ @property # type: ignore
+ def status_code(self) -> int: # type: ignore
+ return self._status_code
+
+ @status_code.setter
+ def status_code(self, value: Any) -> None:
+ raise RuntimeError(
+ "Setting .status_code in HTTPMiddleware is not supported. "
+ "If you're writing middleware that requires modifying the response "
+ "status code or sending another response altogether, please consider "
+ "writing pure ASGI middleware. "
+ "See: https://starlette.io/middleware/#pure-asgi-middleware"
+ )
+
+ @property # type: ignore
+ def media_type(self) -> Optional[str]: # type: ignore
+ return self._media_type
+
+ @media_type.setter
+ def media_type(self, value: Any) -> None:
+ raise RuntimeError(
+ "Setting .media_type in HTTPMiddleware is not supported, as it has "
+ "no effect. If you do need to tweak the response "
+ "content type, consider: response.headers['Content-Type'] = ..."
+ )
+
+ @property # type: ignore
+ def body(self) -> bytes: # type: ignore
+ raise RuntimeError(
+ "Accessing the response body in HTTPMiddleware is not supported. "
+ "If you're writing middleware that requires peeking into the response "
+ "body, please consider writing pure ASGI middleware and wrapping send(). "
+ "See: https://starlette.io/middleware/#pure-asgi-middleware"
+ )
+
+ @body.setter
+ def body(self, body: bytes) -> None:
+ raise RuntimeError(
+ "Setting the response body in HTTPMiddleware is not supported."
+ "If you're writing middleware that requires modifying the response "
+ "body, please consider writing pure ASGI middleware and wrapping send(). "
+ "See: https://starlette.io/middleware/#pure-asgi-middleware"
+ )
def test_early_response(test_client_factory: Callable[[ASGIApp], TestClient]) -> None:
- async def index(request: Request) -> Response:
- return PlainTextResponse("Hello, world!")
+ async def homepage(request: Request) -> Response:
+ return PlainTextResponse("OK")
class CustomMiddleware(HTTPMiddleware):
async def dispatch(
yield None
app = Starlette(
- routes=[Route("/", index)],
+ routes=[Route("/", homepage)],
middleware=[Middleware(CustomMiddleware)],
)
client = test_client_factory(app)
response = client.get("/")
assert response.status_code == 200
- assert response.text == "Hello, world!"
+ assert response.text == "OK"
response = client.get("/", headers={"X-Early": "true"})
assert response.status_code == 401
class Failed(Exception):
pass
- async def index(request: Request) -> Response:
+ async def failure(request: Request) -> Response:
raise Failed()
class CustomMiddleware(HTTPMiddleware):
pass # `yield <response>` expected
app = Starlette(
- routes=[Route("/", index)],
+ routes=[Route("/fail", failure)],
middleware=[Middleware(CustomMiddleware)],
)
client = test_client_factory(app)
with pytest.raises(RuntimeError, match="no response was returned"):
- client.get("/")
+ client.get("/fail")
def test_no_dispatch_given(
test_client_factory: Callable[[ASGIApp], TestClient]
) -> None:
app = Starlette(middleware=[Middleware(HTTPMiddleware)])
+
client = test_client_factory(app)
with pytest.raises(NotImplementedError, match="No dispatch implementation"):
client.get("/")
+
+
+def test_response_stub_attributes(
+ test_client_factory: Callable[[ASGIApp], TestClient]
+) -> None:
+ async def homepage(request: Request) -> Response:
+ return PlainTextResponse("OK")
+
+ async def dispatch(conn: HTTPConnection) -> AsyncGenerator[None, Response]:
+ response = yield
+ if conn.url.path == "/status_code":
+ assert response.status_code == 200
+ response.status_code = 401
+ if conn.url.path == "/media_type":
+ assert response.media_type == "text/plain; charset=utf-8"
+ response.media_type = "text/csv"
+ if conn.url.path == "/body-get":
+ response.body
+ if conn.url.path == "/body-set":
+ response.body = b"changed"
+
+ app = Starlette(
+ routes=[
+ Route("/status_code", homepage),
+ Route("/media_type", homepage),
+ Route("/body-get", homepage),
+ Route("/body-set", homepage),
+ ],
+ middleware=[Middleware(HTTPMiddleware, dispatch=dispatch)],
+ )
+
+ client = test_client_factory(app)
+
+ with pytest.raises(
+ RuntimeError, match="Setting .status_code in HTTPMiddleware is not supported."
+ ):
+ client.get("/status_code")
+
+ with pytest.raises(
+ RuntimeError, match="Setting .media_type in HTTPMiddleware is not supported"
+ ):
+ client.get("/media_type")
+
+ with pytest.raises(
+ RuntimeError,
+ match="Accessing the response body in HTTPMiddleware is not supported",
+ ):
+ client.get("/body-get")
+
+ with pytest.raises(
+ RuntimeError,
+ match="Setting the response body in HTTPMiddleware is not supported",
+ ):
+ client.get("/body-set")