--- /dev/null
+from starlette.applications import Starlette
+from starlette.exceptions import HTTPException, WebSocketException
+from starlette.routing import WebSocketRoute
+from starlette.websockets import WebSocket
+
+
+async def websocket_endpoint(websocket: WebSocket):
+ await websocket.accept()
+ raise WebSocketException(code=1001)
+ raise HTTPException(400)
+
+
+app = Starlette(routes=[WebSocketRoute("/ws", websocket_endpoint)])
from __future__ import annotations
import typing
+import warnings
from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
-from starlette.exceptions import HTTPException
+from starlette.exceptions import HTTPException, WebSocketException
from starlette.requests import Request
-from starlette.types import (
- ASGIApp,
- ExceptionHandler,
- HTTPExceptionHandler,
- Message,
- Receive,
- Scope,
- Send,
- WebSocketExceptionHandler,
-)
+from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send
from starlette.websockets import WebSocket
ExceptionHandlers = typing.Dict[typing.Any, ExceptionHandler]
def wrap_app_handling_exceptions(app: ASGIApp, conn: Request | WebSocket) -> ASGIApp:
- exception_handlers: ExceptionHandlers
- status_handlers: StatusHandlers
- try:
- exception_handlers, status_handlers = conn.scope["starlette.exception_handlers"]
- except KeyError:
- exception_handlers, status_handlers = {}, {}
-
async def wrapped_app(scope: Scope, receive: Receive, send: Send) -> None:
response_started = False
+ websocket_accepted = False
+ print("websocket_accepted", websocket_accepted)
async def sender(message: Message) -> None:
nonlocal response_started
+ nonlocal websocket_accepted
- if message["type"] == "http.response.start":
+ if message["type"] in ("http.response.start", "websocket.http.response.start"):
response_started = True
+ elif message["type"] == "websocket.accept":
+ websocket_accepted = True
+ print(f"message: {message}")
await send(message)
try:
await app(scope, receive, sender)
except Exception as exc:
- handler = None
+ exception_handlers: ExceptionHandlers
+ status_handlers: StatusHandlers
+ try:
+ exception_handlers, status_handlers = scope["starlette.exception_handlers"]
+ except KeyError:
+ exception_handlers, status_handlers = {}, {}
+ handler = None
if isinstance(exc, HTTPException):
handler = status_handlers.get(exc.status_code)
raise exc
if response_started:
- msg = "Caught handled exception, but response already started."
- raise RuntimeError(msg) from exc
-
- if scope["type"] == "http":
- nonlocal conn
- handler = typing.cast(HTTPExceptionHandler, handler)
- conn = typing.cast(Request, conn)
- if is_async_callable(handler):
- response = await handler(conn, exc)
- else:
- response = await run_in_threadpool(handler, conn, exc)
+ raise RuntimeError("Caught handled exception, but response already started.") from exc
+
+ print("run before the conditional", websocket_accepted)
+ if not websocket_accepted and isinstance(exc, WebSocketException):
+ warnings.warn("WebSocketException used before the websocket connection was accepted.", UserWarning)
+
+ if is_async_callable(handler):
+ response = await handler(conn, exc)
+ else:
+ response = await run_in_threadpool(handler, conn, exc) # type: ignore
+ if response is not None:
await response(scope, receive, sender)
- elif scope["type"] == "websocket":
- handler = typing.cast(WebSocketExceptionHandler, handler)
- conn = typing.cast(WebSocket, conn)
- if is_async_callable(handler):
- await handler(conn, exc)
- else:
- await run_in_threadpool(handler, conn, exc)
return wrapped_app
import typing
-from starlette._exception_handler import (
- ExceptionHandlers,
- StatusHandlers,
- wrap_app_handling_exceptions,
-)
+from starlette._exception_handler import ExceptionHandlers, StatusHandlers, wrap_app_handling_exceptions
from starlette.exceptions import HTTPException, WebSocketException
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] not in ("http", "websocket"):
- await self.app(scope, receive, send)
- return
+ return await self.app(scope, receive, send)
- scope["starlette.exception_handlers"] = (
- self._exception_handlers,
- self._status_handlers,
- )
+ scope["starlette.exception_handlers"] = (self._exception_handlers, self._status_handlers)
conn: Request | WebSocket
if scope["type"] == "http":
from pathlib import Path
from typing import AsyncGenerator, AsyncIterator, Generator
-import anyio
+import anyio.from_thread
import pytest
from starlette import status
def test_websocket_raise_websocket_exception(client: TestClient) -> None:
with client.websocket_connect("/ws-raise-websocket") as session:
response = session.receive()
- assert response == {
- "type": "websocket.close",
- "code": status.WS_1003_UNSUPPORTED_DATA,
- "reason": "",
- }
+ assert response == {"type": "websocket.close", "code": status.WS_1003_UNSUPPORTED_DATA, "reason": ""}
def test_websocket_raise_custom_exception(client: TestClient) -> None:
with client.websocket_connect("/ws-raise-custom") as session:
response = session.receive()
- assert response == {
- "type": "websocket.close",
- "code": status.WS_1013_TRY_AGAIN_LATER,
- "reason": "",
- }
+ assert response == {"type": "websocket.close", "code": status.WS_1013_TRY_AGAIN_LATER, "reason": ""}
def test_middleware(test_client_factory: TestClientFactory) -> None:
]
),
),
- Host(
- "{subdomain}.example.org",
- app=Router(routes=[Route("/", endpoint=custom_subdomain)]),
- ),
+ Host("{subdomain}.example.org", app=Router(routes=[Route("/", endpoint=custom_subdomain)])),
]
with open(path, "w") as file:
file.write("<file content>")
- app = Starlette(
- routes=[
- Mount("/static", StaticFiles(directory=tmpdir)),
- ]
- )
+ app = Starlette(routes=[Mount("/static", StaticFiles(directory=tmpdir))])
client = test_client_factory(app)
async def homepage(request: Request) -> None:
raise RuntimeError()
- app = Starlette(
- routes=[
- Route("/", homepage),
- ],
- )
+ app = Starlette(routes=[Route("/", homepage)])
app.debug = True
client = test_client_factory(app, raise_server_exceptions=False)
async def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Hello, World!")
- app = Starlette(
- routes=[
- Route("/", endpoint=homepage),
- ]
- )
+ app = Starlette(routes=[Route("/", endpoint=homepage)])
client = test_client_factory(app)
response = client.get("/")
await session.send_text("Hello, world!")
await session.close()
- app = Starlette(
- routes=[
- WebSocketRoute("/ws", endpoint=websocket_endpoint),
- ]
- )
+ app = Starlette(routes=[WebSocketRoute("/ws", endpoint=websocket_endpoint)])
client = test_client_factory(app)
with client.websocket_connect("/ws") as session:
cleanup_complete = True
with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"):
- app = Starlette(
- on_startup=[run_startup],
- on_shutdown=[run_cleanup],
- )
+ app = Starlette(on_startup=[run_startup], on_shutdown=[run_cleanup])
assert not startup_complete
assert not cleanup_complete
Route("/with_headers", endpoint=with_headers),
Route("/handled_exc_after_response", endpoint=HandledExcAfterResponse()),
WebSocketRoute("/runtime_error", endpoint=raise_runtime_error),
- Route(
- "/consume_body_in_endpoint_and_handler",
- endpoint=read_body_and_raise_exc,
- methods=["POST"],
- ),
+ Route("/consume_body_in_endpoint_and_handler", endpoint=read_body_and_raise_exc, methods=["POST"]),
]
)
pass # pragma: no cover
-def test_handled_exc_after_response(
- test_client_factory: TestClientFactory,
- client: TestClient,
-) -> None:
+def test_handled_exc_after_response(test_client_factory: TestClientFactory, client: TestClient) -> None:
# A 406 HttpException is raised *after* the response has already been sent.
# The exception middleware should raise a RuntimeError.
with pytest.raises(RuntimeError):
def test_force_500_response(test_client_factory: TestClientFactory) -> None:
- # use a sentinal variable to make sure we actually
+ # use a sentinel variable to make sure we actually
# make it into the endpoint and don't get a 500
# from an incorrect ASGI app signature or something
called = False