From: Thomas Grainger Date: Sun, 29 Dec 2024 14:23:47 +0000 (+0000) Subject: avoid direct use of collapse_excgroups X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=02c221f346ca2f007e843e48f85e854e24ba87be;p=thirdparty%2Fstarlette.git avoid direct use of collapse_excgroups --- diff --git a/starlette/_utils.py b/starlette/_utils.py index e9325016..ca28c6fa 100644 --- a/starlette/_utils.py +++ b/starlette/_utils.py @@ -4,7 +4,9 @@ import asyncio import functools import sys import typing -from contextlib import contextmanager +from contextlib import asynccontextmanager, contextmanager + +import anyio.abc from starlette.types import Scope @@ -95,6 +97,13 @@ def collapse_excgroups() -> typing.Generator[None, None, None]: del exc, cause, tb, context +@asynccontextmanager +async def create_collapsing_task_group() -> typing.AsyncGenerator[anyio.abc.TaskGroup, None]: + with collapse_excgroups(): + async with anyio.create_task_group() as tg: + yield tg + + def get_route_path(scope: Scope) -> str: path: str = scope["path"] root_path = scope.get("root_path", "") diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 6e37c6f6..77404f79 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -4,7 +4,7 @@ import typing import anyio -from starlette._utils import collapse_excgroups +from starlette._utils import create_collapsing_task_group from starlette.requests import ClientDisconnect, Request from starlette.responses import AsyncContentStream, Response from starlette.types import ASGIApp, Message, Receive, Scope, Send @@ -173,8 +173,8 @@ class BaseHTTPMiddleware: return response send_stream, recv_stream = anyio.create_memory_object_stream[Message]() - with recv_stream, send_stream, collapse_excgroups(): - async with anyio.create_task_group() as task_group: + with recv_stream, send_stream: + async with create_collapsing_task_group() as task_group: response = await self.dispatch_func(request, call_next) await response(scope, wrapped_receive, send) response_sent.set() diff --git a/starlette/middleware/wsgi.py b/starlette/middleware/wsgi.py index 6e0a3fae..3e9ad029 100644 --- a/starlette/middleware/wsgi.py +++ b/starlette/middleware/wsgi.py @@ -9,6 +9,7 @@ import warnings import anyio from anyio.abc import ObjectReceiveStream, ObjectSendStream +from starlette._utils import create_collapsing_task_group from starlette.types import Receive, Scope, Send warnings.warn( @@ -102,7 +103,7 @@ class WSGIResponder: more_body = message.get("more_body", False) environ = build_environ(self.scope, body) - async with anyio.create_task_group() as task_group: + async with create_collapsing_task_group() as task_group: task_group.start_soon(self.sender, send) async with self.stream_send: await anyio.to_thread.run_sync(self.wsgi, environ, self.start_response) diff --git a/starlette/responses.py b/starlette/responses.py index c522e7f2..1f8b87be 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -18,7 +18,7 @@ from urllib.parse import quote import anyio import anyio.to_thread -from starlette._utils import collapse_excgroups +from starlette._utils import create_collapsing_task_group from starlette.background import BackgroundTask from starlette.concurrency import iterate_in_threadpool from starlette.datastructures import URL, Headers, MutableHeaders @@ -259,15 +259,14 @@ class StreamingResponse(Response): except OSError: raise ClientDisconnect() else: - with collapse_excgroups(): - async with anyio.create_task_group() as task_group: + async with create_collapsing_task_group() as task_group: - async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None: - await func() - task_group.cancel_scope.cancel() + async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None: + await func() + task_group.cancel_scope.cancel() - task_group.start_soon(wrap, partial(self.stream_response, send)) - await wrap(partial(self.listen_for_disconnect, receive)) + task_group.start_soon(wrap, partial(self.stream_response, send)) + await wrap(partial(self.listen_for_disconnect, receive)) if self.background is not None: await self.background() diff --git a/tests/middleware/test_wsgi.py b/tests/middleware/test_wsgi.py index 3511c89c..418f0946 100644 --- a/tests/middleware/test_wsgi.py +++ b/tests/middleware/test_wsgi.py @@ -4,7 +4,6 @@ from typing import Any, Callable import pytest -from starlette._utils import collapse_excgroups from starlette.middleware.wsgi import WSGIMiddleware, build_environ from tests.types import TestClientFactory @@ -86,7 +85,7 @@ def test_wsgi_exception(test_client_factory: TestClientFactory) -> None: # The HTTP protocol implementations would catch this error and return 500. app = WSGIMiddleware(raise_exception) client = test_client_factory(app) - with pytest.raises(RuntimeError), collapse_excgroups(): + with pytest.raises(RuntimeError): client.get("/")