import typing
from starlette.datastructures import State, URLPath
-from starlette.exceptions import ExceptionMiddleware
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.errors import ServerErrorMiddleware
+from starlette.middleware.exceptions import ExceptionMiddleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import BaseRoute, Router
-import asyncio
import http
import typing
+import warnings
-from starlette.concurrency import run_in_threadpool
-from starlette.requests import Request
-from starlette.responses import PlainTextResponse, Response
-from starlette.types import ASGIApp, Message, Receive, Scope, Send
+__all__ = ("HTTPException",)
class HTTPException(Exception):
return f"{class_name}(status_code={self.status_code!r}, detail={self.detail!r})"
-class ExceptionMiddleware:
- def __init__(
- self,
- app: ASGIApp,
- handlers: typing.Optional[
- typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]]
- ] = None,
- debug: bool = False,
- ) -> None:
- self.app = app
- self.debug = debug # TODO: We ought to handle 404 cases if debug is set.
- self._status_handlers: typing.Dict[int, typing.Callable] = {}
- self._exception_handlers: typing.Dict[
- typing.Type[Exception], typing.Callable
- ] = {HTTPException: self.http_exception}
- if handlers is not None:
- for key, value in handlers.items():
- self.add_exception_handler(key, value)
-
- def add_exception_handler(
- self,
- exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
- handler: typing.Callable[[Request, Exception], Response],
- ) -> None:
- if isinstance(exc_class_or_status_code, int):
- self._status_handlers[exc_class_or_status_code] = handler
- else:
- assert issubclass(exc_class_or_status_code, Exception)
- self._exception_handlers[exc_class_or_status_code] = handler
-
- def _lookup_exception_handler(
- self, exc: Exception
- ) -> typing.Optional[typing.Callable]:
- for cls in type(exc).__mro__:
- if cls in self._exception_handlers:
- return self._exception_handlers[cls]
- return None
-
- async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
- if scope["type"] != "http":
- await self.app(scope, receive, send)
- return
+__deprecated__ = "ExceptionMiddleware"
- response_started = False
- async def sender(message: Message) -> None:
- nonlocal response_started
+def __getattr__(name: str) -> typing.Any: # pragma: no cover
+ if name == __deprecated__:
+ from starlette.middleware.exceptions import ExceptionMiddleware
- if message["type"] == "http.response.start":
- response_started = True
- await send(message)
-
- try:
- await self.app(scope, receive, sender)
- except Exception as exc:
- handler = None
-
- if isinstance(exc, HTTPException):
- handler = self._status_handlers.get(exc.status_code)
-
- if handler is None:
- handler = self._lookup_exception_handler(exc)
-
- if handler is None:
- raise exc
-
- if response_started:
- msg = "Caught handled exception, but response already started."
- raise RuntimeError(msg) from exc
+ warnings.warn(
+ f"{__deprecated__} is deprecated on `starlette.exceptions`. "
+ f"Import it from `starlette.middleware.exceptions` instead.",
+ category=DeprecationWarning,
+ stacklevel=3,
+ )
+ return ExceptionMiddleware
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
- request = Request(scope, receive=receive)
- if asyncio.iscoroutinefunction(handler):
- response = await handler(request, exc)
- else:
- response = await run_in_threadpool(handler, request, exc)
- await response(scope, receive, sender)
- def http_exception(self, request: Request, exc: HTTPException) -> Response:
- if exc.status_code in {204, 304}:
- return Response(status_code=exc.status_code, headers=exc.headers)
- return PlainTextResponse(
- exc.detail, status_code=exc.status_code, headers=exc.headers
- )
+def __dir__() -> typing.List[str]:
+ return sorted(list(__all__) + [__deprecated__]) # pragma: no cover
return src.decode("latin-1")
+class MultiPartException(Exception):
+ def __init__(self, message: str) -> None:
+ self.message = message
+
+
class FormParser:
def __init__(
self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
charset = params.get(b"charset", "utf-8")
if type(charset) == bytes:
charset = charset.decode("latin-1")
- boundary = params[b"boundary"]
+ try:
+ boundary = params[b"boundary"]
+ except KeyError:
+ raise MultiPartException("Missing boundary in multipart.")
# Callbacks dictionary.
callbacks = {
--- /dev/null
+import asyncio
+import typing
+
+from starlette.concurrency import run_in_threadpool
+from starlette.exceptions import HTTPException
+from starlette.requests import Request
+from starlette.responses import PlainTextResponse, Response
+from starlette.types import ASGIApp, Message, Receive, Scope, Send
+
+
+class ExceptionMiddleware:
+ def __init__(
+ self,
+ app: ASGIApp,
+ handlers: typing.Optional[
+ typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]]
+ ] = None,
+ debug: bool = False,
+ ) -> None:
+ self.app = app
+ self.debug = debug # TODO: We ought to handle 404 cases if debug is set.
+ self._status_handlers: typing.Dict[int, typing.Callable] = {}
+ self._exception_handlers: typing.Dict[
+ typing.Type[Exception], typing.Callable
+ ] = {HTTPException: self.http_exception}
+ if handlers is not None:
+ for key, value in handlers.items():
+ self.add_exception_handler(key, value)
+
+ def add_exception_handler(
+ self,
+ exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
+ handler: typing.Callable[[Request, Exception], Response],
+ ) -> None:
+ if isinstance(exc_class_or_status_code, int):
+ self._status_handlers[exc_class_or_status_code] = handler
+ else:
+ assert issubclass(exc_class_or_status_code, Exception)
+ self._exception_handlers[exc_class_or_status_code] = handler
+
+ def _lookup_exception_handler(
+ self, exc: Exception
+ ) -> typing.Optional[typing.Callable]:
+ for cls in type(exc).__mro__:
+ if cls in self._exception_handlers:
+ return self._exception_handlers[cls]
+ return None
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ if scope["type"] != "http":
+ await self.app(scope, receive, send)
+ return
+
+ response_started = False
+
+ async def sender(message: Message) -> None:
+ nonlocal response_started
+
+ if message["type"] == "http.response.start":
+ response_started = True
+ await send(message)
+
+ try:
+ await self.app(scope, receive, sender)
+ except Exception as exc:
+ handler = None
+
+ if isinstance(exc, HTTPException):
+ handler = self._status_handlers.get(exc.status_code)
+
+ if handler is None:
+ handler = self._lookup_exception_handler(exc)
+
+ if handler is None:
+ raise exc
+
+ if response_started:
+ msg = "Caught handled exception, but response already started."
+ raise RuntimeError(msg) from exc
+
+ request = Request(scope, receive=receive)
+ if asyncio.iscoroutinefunction(handler):
+ response = await handler(request, exc)
+ else:
+ response = await run_in_threadpool(handler, request, exc)
+ await response(scope, receive, sender)
+
+ def http_exception(self, request: Request, exc: HTTPException) -> Response:
+ if exc.status_code in {204, 304}:
+ return Response(status_code=exc.status_code, headers=exc.headers)
+ return PlainTextResponse(
+ exc.detail, status_code=exc.status_code, headers=exc.headers
+ )
import anyio
from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State
-from starlette.formparsers import FormParser, MultiPartParser
+from starlette.exceptions import HTTPException
+from starlette.formparsers import FormParser, MultiPartException, MultiPartParser
from starlette.types import Message, Receive, Scope, Send
try:
content_type_header = self.headers.get("Content-Type")
content_type, options = parse_options_header(content_type_header)
if content_type == b"multipart/form-data":
- multipart_parser = MultiPartParser(self.headers, self.stream())
- self._form = await multipart_parser.parse()
+ try:
+ multipart_parser = MultiPartParser(self.headers, self.stream())
+ self._form = await multipart_parser.parse()
+ except MultiPartException as exc:
+ if "app" in self.scope:
+ raise HTTPException(status_code=400, detail=exc.message)
+ raise exc
elif content_type == b"application/x-www-form-urlencoded":
form_parser = FormParser(self.headers, self.stream())
self._form = await form_parser.parse()
+import warnings
+
import pytest
-from starlette.exceptions import ExceptionMiddleware, HTTPException
+from starlette.exceptions import HTTPException
+from starlette.middleware.exceptions import ExceptionMiddleware
from starlette.responses import PlainTextResponse
from starlette.routing import Route, Router, WebSocketRoute
assert repr(CustomHTTPException(500, detail="Something custom")) == (
"CustomHTTPException(status_code=500, detail='Something custom')"
)
+
+
+def test_exception_middleware_deprecation() -> None:
+ # this test should be removed once the deprecation shim is removed
+ with pytest.warns(DeprecationWarning):
+ from starlette.exceptions import ExceptionMiddleware # noqa: F401
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("error")
+ import starlette.exceptions
+
+ with pytest.warns(DeprecationWarning):
+ starlette.exceptions.ExceptionMiddleware
import os
import typing
+from contextlib import nullcontext as does_not_raise
import pytest
-from starlette.formparsers import UploadFile, _user_safe_decode
+from starlette.applications import Starlette
+from starlette.formparsers import MultiPartException, UploadFile, _user_safe_decode
from starlette.requests import Request
from starlette.responses import JSONResponse
+from starlette.routing import Mount
+from starlette.testclient import TestClient
class ForceMultipartDict(dict):
assert result == "abc"
-def test_missing_boundary_parameter(test_client_factory):
+@pytest.mark.parametrize(
+ "app,expectation",
+ [
+ (app, pytest.raises(MultiPartException)),
+ (Starlette(routes=[Mount("/", app=app)]), does_not_raise()),
+ ],
+)
+def test_missing_boundary_parameter(
+ app, expectation, test_client_factory: typing.Callable[..., TestClient]
+) -> None:
client = test_client_factory(app)
- with pytest.raises(KeyError, match="boundary"):
- client.post(
+ with expectation:
+ res = client.post(
"/",
data=(
# file
),
headers={"Content-Type": "multipart/form-data; charset=utf-8"},
)
+ assert res.status_code == 400
+ assert res.text == "Missing boundary in multipart."