]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Add 400 response when `boundary` is missing (#1617)
authorMarcelo Trylesinski <marcelotryle@gmail.com>
Thu, 19 May 2022 18:33:03 +0000 (20:33 +0200)
committerGitHub <noreply@github.com>
Thu, 19 May 2022 18:33:03 +0000 (20:33 +0200)
* Add 400 response on MultiParser exceptions

* Add 400 response on MultiParser exceptions [2]

* Add tests and remove name from here

* Remove args from exception

* Move `ExceptionMiddleware` to `starlette.middleware.exceptions`

* add test for deprecation shim

Co-authored-by: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
starlette/applications.py
starlette/exceptions.py
starlette/formparsers.py
starlette/middleware/exceptions.py [new file with mode: 0644]
starlette/requests.py
tests/test_exceptions.py
tests/test_formparsers.py

index 8c5154449b488dc085fb976991258d2bbea6977d..c3daade5cf11c9ce51afafe95d0a65ff17da1c0a 100644 (file)
@@ -1,10 +1,10 @@
 import typing
 
 from starlette.datastructures import State, URLPath
-from starlette.exceptions import ExceptionMiddleware
 from starlette.middleware import Middleware
 from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.middleware.errors import ServerErrorMiddleware
+from starlette.middleware.exceptions import ExceptionMiddleware
 from starlette.requests import Request
 from starlette.responses import Response
 from starlette.routing import BaseRoute, Router
index 61039c59801fc40c8717e866413325f7db55dcd1..2b5acddb53df956a20b21325a45c16ebd653f505 100644 (file)
@@ -1,11 +1,8 @@
-import asyncio
 import http
 import typing
+import warnings
 
-from starlette.concurrency import run_in_threadpool
-from starlette.requests import Request
-from starlette.responses import PlainTextResponse, Response
-from starlette.types import ASGIApp, Message, Receive, Scope, Send
+__all__ = ("HTTPException",)
 
 
 class HTTPException(Exception):
@@ -26,86 +23,22 @@ class HTTPException(Exception):
         return f"{class_name}(status_code={self.status_code!r}, detail={self.detail!r})"
 
 
-class ExceptionMiddleware:
-    def __init__(
-        self,
-        app: ASGIApp,
-        handlers: typing.Optional[
-            typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]]
-        ] = None,
-        debug: bool = False,
-    ) -> None:
-        self.app = app
-        self.debug = debug  # TODO: We ought to handle 404 cases if debug is set.
-        self._status_handlers: typing.Dict[int, typing.Callable] = {}
-        self._exception_handlers: typing.Dict[
-            typing.Type[Exception], typing.Callable
-        ] = {HTTPException: self.http_exception}
-        if handlers is not None:
-            for key, value in handlers.items():
-                self.add_exception_handler(key, value)
-
-    def add_exception_handler(
-        self,
-        exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
-        handler: typing.Callable[[Request, Exception], Response],
-    ) -> None:
-        if isinstance(exc_class_or_status_code, int):
-            self._status_handlers[exc_class_or_status_code] = handler
-        else:
-            assert issubclass(exc_class_or_status_code, Exception)
-            self._exception_handlers[exc_class_or_status_code] = handler
-
-    def _lookup_exception_handler(
-        self, exc: Exception
-    ) -> typing.Optional[typing.Callable]:
-        for cls in type(exc).__mro__:
-            if cls in self._exception_handlers:
-                return self._exception_handlers[cls]
-        return None
-
-    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
-        if scope["type"] != "http":
-            await self.app(scope, receive, send)
-            return
+__deprecated__ = "ExceptionMiddleware"
 
-        response_started = False
 
-        async def sender(message: Message) -> None:
-            nonlocal response_started
+def __getattr__(name: str) -> typing.Any:  # pragma: no cover
+    if name == __deprecated__:
+        from starlette.middleware.exceptions import ExceptionMiddleware
 
-            if message["type"] == "http.response.start":
-                response_started = True
-            await send(message)
-
-        try:
-            await self.app(scope, receive, sender)
-        except Exception as exc:
-            handler = None
-
-            if isinstance(exc, HTTPException):
-                handler = self._status_handlers.get(exc.status_code)
-
-            if handler is None:
-                handler = self._lookup_exception_handler(exc)
-
-            if handler is None:
-                raise exc
-
-            if response_started:
-                msg = "Caught handled exception, but response already started."
-                raise RuntimeError(msg) from exc
+        warnings.warn(
+            f"{__deprecated__} is deprecated on `starlette.exceptions`. "
+            f"Import it from `starlette.middleware.exceptions` instead.",
+            category=DeprecationWarning,
+            stacklevel=3,
+        )
+        return ExceptionMiddleware
+    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
 
-            request = Request(scope, receive=receive)
-            if asyncio.iscoroutinefunction(handler):
-                response = await handler(request, exc)
-            else:
-                response = await run_in_threadpool(handler, request, exc)
-            await response(scope, receive, sender)
 
-    def http_exception(self, request: Request, exc: HTTPException) -> Response:
-        if exc.status_code in {204, 304}:
-            return Response(status_code=exc.status_code, headers=exc.headers)
-        return PlainTextResponse(
-            exc.detail, status_code=exc.status_code, headers=exc.headers
-        )
+def __dir__() -> typing.List[str]:
+    return sorted(list(__all__) + [__deprecated__])  # pragma: no cover
index fd19492293750629e35ca58d7e6eb6eeb8674ae1..4cde71b676a142bc9b5c00e4622a8a24a1953b3b 100644 (file)
@@ -38,6 +38,11 @@ def _user_safe_decode(src: bytes, codec: str) -> str:
         return src.decode("latin-1")
 
 
+class MultiPartException(Exception):
+    def __init__(self, message: str) -> None:
+        self.message = message
+
+
 class FormParser:
     def __init__(
         self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
@@ -159,7 +164,10 @@ class MultiPartParser:
         charset = params.get(b"charset", "utf-8")
         if type(charset) == bytes:
             charset = charset.decode("latin-1")
-        boundary = params[b"boundary"]
+        try:
+            boundary = params[b"boundary"]
+        except KeyError:
+            raise MultiPartException("Missing boundary in multipart.")
 
         # Callbacks dictionary.
         callbacks = {
diff --git a/starlette/middleware/exceptions.py b/starlette/middleware/exceptions.py
new file mode 100644 (file)
index 0000000..a3b4633
--- /dev/null
@@ -0,0 +1,93 @@
+import asyncio
+import typing
+
+from starlette.concurrency import run_in_threadpool
+from starlette.exceptions import HTTPException
+from starlette.requests import Request
+from starlette.responses import PlainTextResponse, Response
+from starlette.types import ASGIApp, Message, Receive, Scope, Send
+
+
+class ExceptionMiddleware:
+    def __init__(
+        self,
+        app: ASGIApp,
+        handlers: typing.Optional[
+            typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]]
+        ] = None,
+        debug: bool = False,
+    ) -> None:
+        self.app = app
+        self.debug = debug  # TODO: We ought to handle 404 cases if debug is set.
+        self._status_handlers: typing.Dict[int, typing.Callable] = {}
+        self._exception_handlers: typing.Dict[
+            typing.Type[Exception], typing.Callable
+        ] = {HTTPException: self.http_exception}
+        if handlers is not None:
+            for key, value in handlers.items():
+                self.add_exception_handler(key, value)
+
+    def add_exception_handler(
+        self,
+        exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
+        handler: typing.Callable[[Request, Exception], Response],
+    ) -> None:
+        if isinstance(exc_class_or_status_code, int):
+            self._status_handlers[exc_class_or_status_code] = handler
+        else:
+            assert issubclass(exc_class_or_status_code, Exception)
+            self._exception_handlers[exc_class_or_status_code] = handler
+
+    def _lookup_exception_handler(
+        self, exc: Exception
+    ) -> typing.Optional[typing.Callable]:
+        for cls in type(exc).__mro__:
+            if cls in self._exception_handlers:
+                return self._exception_handlers[cls]
+        return None
+
+    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+        if scope["type"] != "http":
+            await self.app(scope, receive, send)
+            return
+
+        response_started = False
+
+        async def sender(message: Message) -> None:
+            nonlocal response_started
+
+            if message["type"] == "http.response.start":
+                response_started = True
+            await send(message)
+
+        try:
+            await self.app(scope, receive, sender)
+        except Exception as exc:
+            handler = None
+
+            if isinstance(exc, HTTPException):
+                handler = self._status_handlers.get(exc.status_code)
+
+            if handler is None:
+                handler = self._lookup_exception_handler(exc)
+
+            if handler is None:
+                raise exc
+
+            if response_started:
+                msg = "Caught handled exception, but response already started."
+                raise RuntimeError(msg) from exc
+
+            request = Request(scope, receive=receive)
+            if asyncio.iscoroutinefunction(handler):
+                response = await handler(request, exc)
+            else:
+                response = await run_in_threadpool(handler, request, exc)
+            await response(scope, receive, sender)
+
+    def http_exception(self, request: Request, exc: HTTPException) -> Response:
+        if exc.status_code in {204, 304}:
+            return Response(status_code=exc.status_code, headers=exc.headers)
+        return PlainTextResponse(
+            exc.detail, status_code=exc.status_code, headers=exc.headers
+        )
index c738ebaa40841a3fef896b0b6a06eb43044f02a3..66c510cfee10ce2408e3f715f5661e68210cbc67 100644 (file)
@@ -6,7 +6,8 @@ from http import cookies as http_cookies
 import anyio
 
 from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State
-from starlette.formparsers import FormParser, MultiPartParser
+from starlette.exceptions import HTTPException
+from starlette.formparsers import FormParser, MultiPartException, MultiPartParser
 from starlette.types import Message, Receive, Scope, Send
 
 try:
@@ -250,8 +251,13 @@ class Request(HTTPConnection):
             content_type_header = self.headers.get("Content-Type")
             content_type, options = parse_options_header(content_type_header)
             if content_type == b"multipart/form-data":
-                multipart_parser = MultiPartParser(self.headers, self.stream())
-                self._form = await multipart_parser.parse()
+                try:
+                    multipart_parser = MultiPartParser(self.headers, self.stream())
+                    self._form = await multipart_parser.parse()
+                except MultiPartException as exc:
+                    if "app" in self.scope:
+                        raise HTTPException(status_code=400, detail=exc.message)
+                    raise exc
             elif content_type == b"application/x-www-form-urlencoded":
                 form_parser = FormParser(self.headers, self.stream())
                 self._form = await form_parser.parse()
index 50f677467411fcacdace572d62a33442c67bc41a..9acd421540e6f48ba69c06752f3b3ffbba54d196 100644 (file)
@@ -1,6 +1,9 @@
+import warnings
+
 import pytest
 
-from starlette.exceptions import ExceptionMiddleware, HTTPException
+from starlette.exceptions import HTTPException
+from starlette.middleware.exceptions import ExceptionMiddleware
 from starlette.responses import PlainTextResponse
 from starlette.routing import Route, Router, WebSocketRoute
 
@@ -130,3 +133,16 @@ def test_repr():
     assert repr(CustomHTTPException(500, detail="Something custom")) == (
         "CustomHTTPException(status_code=500, detail='Something custom')"
     )
+
+
+def test_exception_middleware_deprecation() -> None:
+    # this test should be removed once the deprecation shim is removed
+    with pytest.warns(DeprecationWarning):
+        from starlette.exceptions import ExceptionMiddleware  # noqa: F401
+
+    with warnings.catch_warnings():
+        warnings.simplefilter("error")
+        import starlette.exceptions
+
+    with pytest.warns(DeprecationWarning):
+        starlette.exceptions.ExceptionMiddleware
index 3d4b0a199314bc1b0ee54da84fce945b9b440419..6710595297e4da1a5bc700828e40432ce827408e 100644 (file)
@@ -1,11 +1,15 @@
 import os
 import typing
+from contextlib import nullcontext as does_not_raise
 
 import pytest
 
-from starlette.formparsers import UploadFile, _user_safe_decode
+from starlette.applications import Starlette
+from starlette.formparsers import MultiPartException, UploadFile, _user_safe_decode
 from starlette.requests import Request
 from starlette.responses import JSONResponse
+from starlette.routing import Mount
+from starlette.testclient import TestClient
 
 
 class ForceMultipartDict(dict):
@@ -390,10 +394,19 @@ def test_user_safe_decode_ignores_wrong_charset():
     assert result == "abc"
 
 
-def test_missing_boundary_parameter(test_client_factory):
+@pytest.mark.parametrize(
+    "app,expectation",
+    [
+        (app, pytest.raises(MultiPartException)),
+        (Starlette(routes=[Mount("/", app=app)]), does_not_raise()),
+    ],
+)
+def test_missing_boundary_parameter(
+    app, expectation, test_client_factory: typing.Callable[..., TestClient]
+) -> None:
     client = test_client_factory(app)
-    with pytest.raises(KeyError, match="boundary"):
-        client.post(
+    with expectation:
+        res = client.post(
             "/",
             data=(
                 # file
@@ -403,3 +416,5 @@ def test_missing_boundary_parameter(test_client_factory):
             ),
             headers={"Content-Type": "multipart/form-data; charset=utf-8"},
         )
+        assert res.status_code == 400
+        assert res.text == "Missing boundary in multipart."