]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Prevent BaseHTTPMiddleware from hiding errors of StreamingResponse (#1459)
authorOleksandr Fedorov <2283679+o-fedorov@users.noreply.github.com>
Mon, 31 Jan 2022 10:12:15 +0000 (12:12 +0200)
committerGitHub <noreply@github.com>
Mon, 31 Jan 2022 10:12:15 +0000 (11:12 +0100)
* Prevent BaseHTTPMiddleware from hiding errors of StreamingResponse

* Apply notes from PR:

* remove `nonlocal app_exc`;
* add extra test.

* Fix formatting

Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
starlette/middleware/base.py
tests/middleware/test_base.py

index 423f40777ebfb8ae86b9cddbf63b7240e4c6427f..bfb4a54a48db9f9aba15a200ad936834f7f05748 100644 (file)
@@ -52,6 +52,9 @@ class BaseHTTPMiddleware:
                         assert message["type"] == "http.response.body"
                         yield message.get("body", b"")
 
+                if app_exc is not None:
+                    raise app_exc
+
             response = StreamingResponse(
                 status_code=message["status"], content=body_stream()
             )
index c6bfd49fc89826af064106d034397faacded59d2..32468dcd28c8234f6e073f9e0a75b5ac403e0a53 100644 (file)
@@ -3,7 +3,7 @@ import pytest
 from starlette.applications import Starlette
 from starlette.middleware import Middleware
 from starlette.middleware.base import BaseHTTPMiddleware
-from starlette.responses import PlainTextResponse
+from starlette.responses import PlainTextResponse, StreamingResponse
 from starlette.routing import Route
 
 
@@ -28,6 +28,16 @@ def exc(request):
     raise Exception("Exc")
 
 
+@app.route("/exc-stream")
+def exc_stream(request):
+    return StreamingResponse(_generate_faulty_stream())
+
+
+def _generate_faulty_stream():
+    yield b"Ok"
+    raise Exception("Faulty Stream")
+
+
 @app.route("/no-response")
 class NoResponse:
     def __init__(self, scope, receive, send):
@@ -56,6 +66,10 @@ def test_custom_middleware(test_client_factory):
         response = client.get("/exc")
     assert str(ctx.value) == "Exc"
 
+    with pytest.raises(Exception) as ctx:
+        response = client.get("/exc-stream")
+    assert str(ctx.value) == "Faulty Stream"
+
     with pytest.raises(RuntimeError):
         response = client.get("/no-response")
 
@@ -158,3 +172,13 @@ def test_fully_evaluated_response(test_client_factory):
     client = test_client_factory(app)
     response = client.get("/does_not_exist")
     assert response.text == "Custom"
+
+
+def test_exception_on_mounted_apps(test_client_factory):
+    sub_app = Starlette(routes=[Route("/", exc)])
+    app.mount("/sub", sub_app)
+
+    client = test_client_factory(app)
+    with pytest.raises(Exception) as ctx:
+        client.get("/sub/")
+    assert str(ctx.value) == "Exc"