install_requires=[
"anyio>=3.4.0,<5",
"typing_extensions>=3.10.0; python_version < '3.10'",
- "async_generator; python < '3.10'",
+ "async_generator; python_version < '3.10'",
],
extras_require={
"full": [
data, usedforsecurity=usedforsecurity
).hexdigest()
-
except TypeError: # pragma: no cover
def md5_hexdigest(data: bytes, *, usedforsecurity: bool = True) -> str:
try:
- from contextlib import aclosing
+ from contextlib import aclosing # type: ignore[attr-defined]
except ImportError: # Python < 3.10
from async_generator import aclosing # type: ignore
-from functools import partial
-from typing import AsyncGenerator, Callable, Optional
+from typing import AsyncGenerator, Callable, Optional, Union
from .._compat import aclosing
from ..datastructures import MutableHeaders
from ..responses import Response
from ..types import ASGIApp, Message, Receive, Scope, Send
-HTTPDispatchFlow = AsyncGenerator[Optional[Response], Response]
+_HTTPDispatchFlow = Union[
+ AsyncGenerator[None, Response], AsyncGenerator[ASGIApp, Response]
+]
class HTTPMiddleware:
def __init__(
self,
app: ASGIApp,
- dispatch_func: Optional[Callable[[Scope], HTTPDispatchFlow]] = None,
+ dispatch_func: Optional[Callable[[Scope], _HTTPDispatchFlow]] = None,
) -> None:
self.app = app
self.dispatch_func = self.dispatch if dispatch_func is None else dispatch_func
- def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
+ def dispatch(self, scope: Scope) -> _HTTPDispatchFlow:
raise NotImplementedError # pragma: no cover
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
async with aclosing(self.dispatch(scope)) as flow:
# Kick the flow until the first `yield`.
# Might respond early before we call into the app.
- early_response = await flow.__anext__()
+ maybe_early_response = await flow.__anext__()
- if early_response is not None:
- await early_response(scope, receive, send)
+ if maybe_early_response is not None:
+ await maybe_early_response(scope, receive, send)
return
response_started = set[bool]()
- wrapped_send = partial(
- self._send,
- flow=flow,
- response_started=response_started,
- send=send,
- )
+ async def wrapped_send(message: Message) -> None:
+ if message["type"] == "http.response.start":
+ response_started.add(True)
+
+ response = Response(status_code=message["status"])
+ response.raw_headers.clear()
+
+ try:
+ await flow.asend(response)
+ except StopAsyncIteration:
+ pass
+ else:
+ raise RuntimeError("dispatch() should yield exactly once")
+
+ headers = MutableHeaders(raw=message["headers"])
+ headers.update(response.headers)
+
+ await send(message)
try:
await self.app(scope, receive, wrapped_send)
if not response_started:
raise RuntimeError("No response returned.")
-
- async def _send(
- self,
- message: Message,
- *,
- flow: HTTPDispatchFlow,
- response_started: set,
- send: Send,
- ) -> None:
- if message["type"] == "http.response.start":
- response_started.add(True)
-
- response = Response(status_code=message["status"])
- response.raw_headers.clear()
-
- try:
- await flow.asend(response)
- except StopAsyncIteration:
- pass
- else:
- raise RuntimeError("dispatch() should yield exactly once")
-
- headers = MutableHeaders(raw=message["headers"])
- headers.update(response.headers)
-
- await send(message)
import contextvars
+from typing import AsyncGenerator
import pytest
from starlette.applications import Starlette
from starlette.middleware import Middleware
-from starlette.middleware.http import HTTPDispatchFlow, HTTPMiddleware
-from starlette.responses import PlainTextResponse, StreamingResponse
-from starlette.routing import Mount, Route, WebSocketRoute
+from starlette.middleware.http import HTTPMiddleware
+from starlette.responses import PlainTextResponse, Response, StreamingResponse
+from starlette.routing import Route, WebSocketRoute
from starlette.types import ASGIApp, Receive, Scope, Send
class CustomMiddleware(HTTPMiddleware):
- async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
- response = yield None
+ async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
+ response = yield
response.headers["Custom-Header"] = "Example"
expected_value2 = "bar"
class aMiddleware(HTTPMiddleware):
- async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
+ async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
scope["state_foo"] = expected_value1
- yield None
+ yield
class bMiddleware(HTTPMiddleware):
- async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
+ async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
scope["state_bar"] = expected_value2
- response = yield None
+ response = yield
response.headers["X-State-Foo"] = scope["state_foo"]
class cMiddleware(HTTPMiddleware):
- async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
- response = yield None
+ async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
+ response = yield
response.headers["X-State-Bar"] = scope["state_bar"]
def homepage(request):
def test_fully_evaluated_response(test_client_factory):
# Test for https://github.com/encode/starlette/issues/1022
class CustomMiddleware(HTTPMiddleware):
- async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
+ async def dispatch(self, scope: Scope) -> AsyncGenerator[Response, Response]:
yield PlainTextResponse("Custom")
app = Starlette(middleware=[Middleware(CustomMiddleware)])
assert response.text == "Custom"
-def test_exception_on_mounted_apps(test_client_factory):
- sub_app = Starlette(routes=[Route("/", exc)])
- app = Starlette(routes=[Mount("/sub", app=sub_app)])
-
- client = test_client_factory(app)
- with pytest.raises(Exception) as ctx:
- client.get("/sub/")
- assert str(ctx.value) == "Exc"
-
-
ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar")
class CustomMiddlewareUsingHTTPMiddleware(HTTPMiddleware):
- async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
+ async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
ctxvar.set("set by middleware")
- yield None
+ yield
assert ctxvar.get() == "set by endpoint"