def test_early_response(test_client_factory):
+ async def index(request):
+ return PlainTextResponse("Hello, world!")
+
class CustomMiddleware(HTTPMiddleware):
async def dispatch(
self, conn: HTTPConnection
- ) -> AsyncGenerator[Response, Response]:
- yield Response(status_code=401)
+ ) -> AsyncGenerator[Optional[Response], Response]:
+ if conn.headers.get("X-Early") == "true":
+ yield Response(status_code=401)
+ else:
+ yield None
- app = Starlette(middleware=[Middleware(CustomMiddleware)])
+ app = Starlette(
+ routes=[Route("/", index)],
+ middleware=[Middleware(CustomMiddleware)],
+ )
client = test_client_factory(app)
response = client.get("/")
+ assert response.status_code == 200
+ assert response.text == "Hello, world!"
+ response = client.get("/", headers={"X-Early": "true"})
assert response.status_code == 401
def test_too_many_yields(test_client_factory) -> None:
- class BadMiddleware(HTTPMiddleware):
+ class CustomMiddleware(HTTPMiddleware):
async def dispatch(
self, conn: HTTPConnection
) -> AsyncGenerator[None, Response]:
_ = yield
yield
- app = Starlette(middleware=[Middleware(BadMiddleware)])
+ app = Starlette(middleware=[Middleware(CustomMiddleware)])
client = test_client_factory(app)
with pytest.raises(RuntimeError, match="should yield exactly once"):
async def failure(request):
raise Failed()
- class ErrorMiddleware(HTTPMiddleware):
+ class CustomMiddleware(HTTPMiddleware):
async def dispatch(
self, conn: HTTPConnection
) -> AsyncGenerator[Optional[Response], Response]:
app = Starlette(
routes=[Route("/fail", failure)],
- middleware=[Middleware(ErrorMiddleware)],
+ middleware=[Middleware(CustomMiddleware)],
)
client = test_client_factory(app)
async def index(request):
raise Failed()
- class BadMiddleware(HTTPMiddleware):
+ class CustomMiddleware(HTTPMiddleware):
async def dispatch(
self, conn: HTTPConnection
) -> AsyncGenerator[None, Response]:
except Failed:
pass
- app = Starlette(routes=[Route("/", index)], middleware=[Middleware(BadMiddleware)])
+ app = Starlette(
+ routes=[Route("/", index)],
+ middleware=[Middleware(CustomMiddleware)],
+ )
client = test_client_factory(app)
with pytest.raises(RuntimeError, match="no response was returned"):
],
)
def test_contextvars(test_client_factory, middleware_cls: type):
- # this has to be an async endpoint because Starlette calls run_in_threadpool
- # on sync endpoints which has it's own set of peculiarities w.r.t propagating
- # contextvars (it propagates them forwards but not backwards)
async def homepage(request):
assert ctxvar.get() == "set by middleware"
ctxvar.set("set by endpoint")