From: Marcelo Trylesinski Date: Thu, 19 May 2022 18:33:03 +0000 (+0200) Subject: Add 400 response when `boundary` is missing (#1617) X-Git-Tag: 0.20.1~9 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=3453fd6b6259393f3744815b49af297c934f32d0;p=thirdparty%2Fstarlette.git Add 400 response when `boundary` is missing (#1617) * Add 400 response on MultiParser exceptions * Add 400 response on MultiParser exceptions [2] * Add tests and remove name from here * Remove args from exception * Move `ExceptionMiddleware` to `starlette.middleware.exceptions` * add test for deprecation shim Co-authored-by: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> --- diff --git a/starlette/applications.py b/starlette/applications.py index 8c515444..c3daade5 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -1,10 +1,10 @@ import typing from starlette.datastructures import State, URLPath -from starlette.exceptions import ExceptionMiddleware from starlette.middleware import Middleware from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.errors import ServerErrorMiddleware +from starlette.middleware.exceptions import ExceptionMiddleware from starlette.requests import Request from starlette.responses import Response from starlette.routing import BaseRoute, Router diff --git a/starlette/exceptions.py b/starlette/exceptions.py index 61039c59..2b5acddb 100644 --- a/starlette/exceptions.py +++ b/starlette/exceptions.py @@ -1,11 +1,8 @@ -import asyncio import http import typing +import warnings -from starlette.concurrency import run_in_threadpool -from starlette.requests import Request -from starlette.responses import PlainTextResponse, Response -from starlette.types import ASGIApp, Message, Receive, Scope, Send +__all__ = ("HTTPException",) class HTTPException(Exception): @@ -26,86 +23,22 @@ class HTTPException(Exception): return f"{class_name}(status_code={self.status_code!r}, detail={self.detail!r})" -class ExceptionMiddleware: - def __init__( - self, - app: ASGIApp, - handlers: typing.Optional[ - typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]] - ] = None, - debug: bool = False, - ) -> None: - self.app = app - self.debug = debug # TODO: We ought to handle 404 cases if debug is set. - self._status_handlers: typing.Dict[int, typing.Callable] = {} - self._exception_handlers: typing.Dict[ - typing.Type[Exception], typing.Callable - ] = {HTTPException: self.http_exception} - if handlers is not None: - for key, value in handlers.items(): - self.add_exception_handler(key, value) - - def add_exception_handler( - self, - exc_class_or_status_code: typing.Union[int, typing.Type[Exception]], - handler: typing.Callable[[Request, Exception], Response], - ) -> None: - if isinstance(exc_class_or_status_code, int): - self._status_handlers[exc_class_or_status_code] = handler - else: - assert issubclass(exc_class_or_status_code, Exception) - self._exception_handlers[exc_class_or_status_code] = handler - - def _lookup_exception_handler( - self, exc: Exception - ) -> typing.Optional[typing.Callable]: - for cls in type(exc).__mro__: - if cls in self._exception_handlers: - return self._exception_handlers[cls] - return None - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - if scope["type"] != "http": - await self.app(scope, receive, send) - return +__deprecated__ = "ExceptionMiddleware" - response_started = False - async def sender(message: Message) -> None: - nonlocal response_started +def __getattr__(name: str) -> typing.Any: # pragma: no cover + if name == __deprecated__: + from starlette.middleware.exceptions import ExceptionMiddleware - if message["type"] == "http.response.start": - response_started = True - await send(message) - - try: - await self.app(scope, receive, sender) - except Exception as exc: - handler = None - - if isinstance(exc, HTTPException): - handler = self._status_handlers.get(exc.status_code) - - if handler is None: - handler = self._lookup_exception_handler(exc) - - if handler is None: - raise exc - - if response_started: - msg = "Caught handled exception, but response already started." - raise RuntimeError(msg) from exc + warnings.warn( + f"{__deprecated__} is deprecated on `starlette.exceptions`. " + f"Import it from `starlette.middleware.exceptions` instead.", + category=DeprecationWarning, + stacklevel=3, + ) + return ExceptionMiddleware + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") - request = Request(scope, receive=receive) - if asyncio.iscoroutinefunction(handler): - response = await handler(request, exc) - else: - response = await run_in_threadpool(handler, request, exc) - await response(scope, receive, sender) - def http_exception(self, request: Request, exc: HTTPException) -> Response: - if exc.status_code in {204, 304}: - return Response(status_code=exc.status_code, headers=exc.headers) - return PlainTextResponse( - exc.detail, status_code=exc.status_code, headers=exc.headers - ) +def __dir__() -> typing.List[str]: + return sorted(list(__all__) + [__deprecated__]) # pragma: no cover diff --git a/starlette/formparsers.py b/starlette/formparsers.py index fd194922..4cde71b6 100644 --- a/starlette/formparsers.py +++ b/starlette/formparsers.py @@ -38,6 +38,11 @@ def _user_safe_decode(src: bytes, codec: str) -> str: return src.decode("latin-1") +class MultiPartException(Exception): + def __init__(self, message: str) -> None: + self.message = message + + class FormParser: def __init__( self, headers: Headers, stream: typing.AsyncGenerator[bytes, None] @@ -159,7 +164,10 @@ class MultiPartParser: charset = params.get(b"charset", "utf-8") if type(charset) == bytes: charset = charset.decode("latin-1") - boundary = params[b"boundary"] + try: + boundary = params[b"boundary"] + except KeyError: + raise MultiPartException("Missing boundary in multipart.") # Callbacks dictionary. callbacks = { diff --git a/starlette/middleware/exceptions.py b/starlette/middleware/exceptions.py new file mode 100644 index 00000000..a3b4633d --- /dev/null +++ b/starlette/middleware/exceptions.py @@ -0,0 +1,93 @@ +import asyncio +import typing + +from starlette.concurrency import run_in_threadpool +from starlette.exceptions import HTTPException +from starlette.requests import Request +from starlette.responses import PlainTextResponse, Response +from starlette.types import ASGIApp, Message, Receive, Scope, Send + + +class ExceptionMiddleware: + def __init__( + self, + app: ASGIApp, + handlers: typing.Optional[ + typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]] + ] = None, + debug: bool = False, + ) -> None: + self.app = app + self.debug = debug # TODO: We ought to handle 404 cases if debug is set. + self._status_handlers: typing.Dict[int, typing.Callable] = {} + self._exception_handlers: typing.Dict[ + typing.Type[Exception], typing.Callable + ] = {HTTPException: self.http_exception} + if handlers is not None: + for key, value in handlers.items(): + self.add_exception_handler(key, value) + + def add_exception_handler( + self, + exc_class_or_status_code: typing.Union[int, typing.Type[Exception]], + handler: typing.Callable[[Request, Exception], Response], + ) -> None: + if isinstance(exc_class_or_status_code, int): + self._status_handlers[exc_class_or_status_code] = handler + else: + assert issubclass(exc_class_or_status_code, Exception) + self._exception_handlers[exc_class_or_status_code] = handler + + def _lookup_exception_handler( + self, exc: Exception + ) -> typing.Optional[typing.Callable]: + for cls in type(exc).__mro__: + if cls in self._exception_handlers: + return self._exception_handlers[cls] + return None + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + response_started = False + + async def sender(message: Message) -> None: + nonlocal response_started + + if message["type"] == "http.response.start": + response_started = True + await send(message) + + try: + await self.app(scope, receive, sender) + except Exception as exc: + handler = None + + if isinstance(exc, HTTPException): + handler = self._status_handlers.get(exc.status_code) + + if handler is None: + handler = self._lookup_exception_handler(exc) + + if handler is None: + raise exc + + if response_started: + msg = "Caught handled exception, but response already started." + raise RuntimeError(msg) from exc + + request = Request(scope, receive=receive) + if asyncio.iscoroutinefunction(handler): + response = await handler(request, exc) + else: + response = await run_in_threadpool(handler, request, exc) + await response(scope, receive, sender) + + def http_exception(self, request: Request, exc: HTTPException) -> Response: + if exc.status_code in {204, 304}: + return Response(status_code=exc.status_code, headers=exc.headers) + return PlainTextResponse( + exc.detail, status_code=exc.status_code, headers=exc.headers + ) diff --git a/starlette/requests.py b/starlette/requests.py index c738ebaa..66c510cf 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -6,7 +6,8 @@ from http import cookies as http_cookies import anyio from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State -from starlette.formparsers import FormParser, MultiPartParser +from starlette.exceptions import HTTPException +from starlette.formparsers import FormParser, MultiPartException, MultiPartParser from starlette.types import Message, Receive, Scope, Send try: @@ -250,8 +251,13 @@ class Request(HTTPConnection): content_type_header = self.headers.get("Content-Type") content_type, options = parse_options_header(content_type_header) if content_type == b"multipart/form-data": - multipart_parser = MultiPartParser(self.headers, self.stream()) - self._form = await multipart_parser.parse() + try: + multipart_parser = MultiPartParser(self.headers, self.stream()) + self._form = await multipart_parser.parse() + except MultiPartException as exc: + if "app" in self.scope: + raise HTTPException(status_code=400, detail=exc.message) + raise exc elif content_type == b"application/x-www-form-urlencoded": form_parser = FormParser(self.headers, self.stream()) self._form = await form_parser.parse() diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 50f67746..9acd4215 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -1,6 +1,9 @@ +import warnings + import pytest -from starlette.exceptions import ExceptionMiddleware, HTTPException +from starlette.exceptions import HTTPException +from starlette.middleware.exceptions import ExceptionMiddleware from starlette.responses import PlainTextResponse from starlette.routing import Route, Router, WebSocketRoute @@ -130,3 +133,16 @@ def test_repr(): assert repr(CustomHTTPException(500, detail="Something custom")) == ( "CustomHTTPException(status_code=500, detail='Something custom')" ) + + +def test_exception_middleware_deprecation() -> None: + # this test should be removed once the deprecation shim is removed + with pytest.warns(DeprecationWarning): + from starlette.exceptions import ExceptionMiddleware # noqa: F401 + + with warnings.catch_warnings(): + warnings.simplefilter("error") + import starlette.exceptions + + with pytest.warns(DeprecationWarning): + starlette.exceptions.ExceptionMiddleware diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py index 3d4b0a19..67105952 100644 --- a/tests/test_formparsers.py +++ b/tests/test_formparsers.py @@ -1,11 +1,15 @@ import os import typing +from contextlib import nullcontext as does_not_raise import pytest -from starlette.formparsers import UploadFile, _user_safe_decode +from starlette.applications import Starlette +from starlette.formparsers import MultiPartException, UploadFile, _user_safe_decode from starlette.requests import Request from starlette.responses import JSONResponse +from starlette.routing import Mount +from starlette.testclient import TestClient class ForceMultipartDict(dict): @@ -390,10 +394,19 @@ def test_user_safe_decode_ignores_wrong_charset(): assert result == "abc" -def test_missing_boundary_parameter(test_client_factory): +@pytest.mark.parametrize( + "app,expectation", + [ + (app, pytest.raises(MultiPartException)), + (Starlette(routes=[Mount("/", app=app)]), does_not_raise()), + ], +) +def test_missing_boundary_parameter( + app, expectation, test_client_factory: typing.Callable[..., TestClient] +) -> None: client = test_client_factory(app) - with pytest.raises(KeyError, match="boundary"): - client.post( + with expectation: + res = client.post( "/", data=( # file @@ -403,3 +416,5 @@ def test_missing_boundary_parameter(test_client_factory): ), headers={"Content-Type": "multipart/form-data; charset=utf-8"}, ) + assert res.status_code == 400 + assert res.text == "Missing boundary in multipart."