]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
✨ Allow custom middlewares to raise `HTTPException`s and propagate them (#2036)
authorAndy Challis <andy.challis@capgemini.com>
Thu, 25 Aug 2022 21:44:40 +0000 (07:44 +1000)
committerGitHub <noreply@github.com>
Thu, 25 Aug 2022 21:44:40 +0000 (23:44 +0200)
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
fastapi/routing.py
tests/test_custom_middleware_exception.py [new file with mode: 0644]

index 8f4d0fa7ef578e97b1dfc7e68bdfb662edcab981..1ac4b388063477b267c6c58027da5fe23eba7f76 100644 (file)
@@ -212,6 +212,8 @@ def get_request_handler(
             raise RequestValidationError(
                 [ErrorWrapper(e, ("body", e.pos))], body=e.doc
             ) from e
+        except HTTPException:
+            raise
         except Exception as e:
             raise HTTPException(
                 status_code=400, detail="There was an error parsing the body"
diff --git a/tests/test_custom_middleware_exception.py b/tests/test_custom_middleware_exception.py
new file mode 100644 (file)
index 0000000..d9b81e7
--- /dev/null
@@ -0,0 +1,95 @@
+from pathlib import Path
+from typing import Optional
+
+from fastapi import APIRouter, FastAPI, File, UploadFile
+from fastapi.exceptions import HTTPException
+from fastapi.testclient import TestClient
+
+app = FastAPI()
+
+router = APIRouter()
+
+
+class ContentSizeLimitMiddleware:
+    """Content size limiting middleware for ASGI applications
+    Args:
+      app (ASGI application): ASGI application
+      max_content_size (optional): the maximum content size allowed in bytes, None for no limit
+    """
+
+    def __init__(self, app: APIRouter, max_content_size: Optional[int] = None):
+        self.app = app
+        self.max_content_size = max_content_size
+
+    def receive_wrapper(self, receive):
+        received = 0
+
+        async def inner():
+            nonlocal received
+            message = await receive()
+            if message["type"] != "http.request":
+                return message  # pragma: no cover
+
+            body_len = len(message.get("body", b""))
+            received += body_len
+            if received > self.max_content_size:
+                raise HTTPException(
+                    422,
+                    detail={
+                        "name": "ContentSizeLimitExceeded",
+                        "code": 999,
+                        "message": "File limit exceeded",
+                    },
+                )
+            return message
+
+        return inner
+
+    async def __call__(self, scope, receive, send):
+        if scope["type"] != "http" or self.max_content_size is None:
+            await self.app(scope, receive, send)
+            return
+
+        wrapper = self.receive_wrapper(receive)
+        await self.app(scope, wrapper, send)
+
+
+@router.post("/middleware")
+def run_middleware(file: UploadFile = File(..., description="Big File")):
+    return {"message": "OK"}
+
+
+app.include_router(router)
+app.add_middleware(ContentSizeLimitMiddleware, max_content_size=2**8)
+
+
+client = TestClient(app)
+
+
+def test_custom_middleware_exception(tmp_path: Path):
+    default_pydantic_max_size = 2**16
+    path = tmp_path / "test.txt"
+    path.write_bytes(b"x" * (default_pydantic_max_size + 1))
+
+    with client:
+        with open(path, "rb") as file:
+            response = client.post("/middleware", files={"file": file})
+        assert response.status_code == 422, response.text
+        assert response.json() == {
+            "detail": {
+                "name": "ContentSizeLimitExceeded",
+                "code": 999,
+                "message": "File limit exceeded",
+            }
+        }
+
+
+def test_custom_middleware_exception_not_raised(tmp_path: Path):
+    path = tmp_path / "test.txt"
+    path.write_bytes(b"<file content>")
+
+    with client:
+        with open(path, "rb") as file:
+            response = client.post("/middleware", files={"file": file})
+        assert response.status_code == 200, response.text
+        assert response.json() == {"message": "OK"}