]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
avoid direct use of collapse_excgroups
authorThomas Grainger <tagrain@gmail.com>
Sun, 29 Dec 2024 14:23:47 +0000 (14:23 +0000)
committerThomas Grainger <tagrain@gmail.com>
Sun, 29 Dec 2024 14:23:47 +0000 (14:23 +0000)
starlette/_utils.py
starlette/middleware/base.py
starlette/middleware/wsgi.py
starlette/responses.py
tests/middleware/test_wsgi.py

index e9325016310ee5a786ea39047bc8011c5471ac01..ca28c6fa50f3354773ac89e5c54f4e3c8072d942 100644 (file)
@@ -4,7 +4,9 @@ import asyncio
 import functools
 import sys
 import typing
-from contextlib import contextmanager
+from contextlib import asynccontextmanager, contextmanager
+
+import anyio.abc
 
 from starlette.types import Scope
 
@@ -95,6 +97,13 @@ def collapse_excgroups() -> typing.Generator[None, None, None]:
             del exc, cause, tb, context
 
 
+@asynccontextmanager
+async def create_collapsing_task_group() -> typing.AsyncGenerator[anyio.abc.TaskGroup, None]:
+    with collapse_excgroups():
+        async with anyio.create_task_group() as tg:
+            yield tg
+
+
 def get_route_path(scope: Scope) -> str:
     path: str = scope["path"]
     root_path = scope.get("root_path", "")
index 6e37c6f603b537a08eae32070ee3545543b566a5..77404f790735e8ffb9ac8c448ff1be1c16df7229 100644 (file)
@@ -4,7 +4,7 @@ import typing
 
 import anyio
 
-from starlette._utils import collapse_excgroups
+from starlette._utils import create_collapsing_task_group
 from starlette.requests import ClientDisconnect, Request
 from starlette.responses import AsyncContentStream, Response
 from starlette.types import ASGIApp, Message, Receive, Scope, Send
@@ -173,8 +173,8 @@ class BaseHTTPMiddleware:
             return response
 
         send_stream, recv_stream = anyio.create_memory_object_stream[Message]()
-        with recv_stream, send_stream, collapse_excgroups():
-            async with anyio.create_task_group() as task_group:
+        with recv_stream, send_stream:
+            async with create_collapsing_task_group() as task_group:
                 response = await self.dispatch_func(request, call_next)
                 await response(scope, wrapped_receive, send)
                 response_sent.set()
index 6e0a3fae6c176485e506b498200ac22c499c6c50..3e9ad0296e21248c82520793e90b9a328c38e562 100644 (file)
@@ -9,6 +9,7 @@ import warnings
 import anyio
 from anyio.abc import ObjectReceiveStream, ObjectSendStream
 
+from starlette._utils import create_collapsing_task_group
 from starlette.types import Receive, Scope, Send
 
 warnings.warn(
@@ -102,7 +103,7 @@ class WSGIResponder:
             more_body = message.get("more_body", False)
         environ = build_environ(self.scope, body)
 
-        async with anyio.create_task_group() as task_group:
+        async with create_collapsing_task_group() as task_group:
             task_group.start_soon(self.sender, send)
             async with self.stream_send:
                 await anyio.to_thread.run_sync(self.wsgi, environ, self.start_response)
index c522e7f23b9710affc0f38fcd358790898f2eef6..1f8b87bea85664da3c1fa018697a9459b6adc537 100644 (file)
@@ -18,7 +18,7 @@ from urllib.parse import quote
 import anyio
 import anyio.to_thread
 
-from starlette._utils import collapse_excgroups
+from starlette._utils import create_collapsing_task_group
 from starlette.background import BackgroundTask
 from starlette.concurrency import iterate_in_threadpool
 from starlette.datastructures import URL, Headers, MutableHeaders
@@ -259,15 +259,14 @@ class StreamingResponse(Response):
             except OSError:
                 raise ClientDisconnect()
         else:
-            with collapse_excgroups():
-                async with anyio.create_task_group() as task_group:
+            async with create_collapsing_task_group() as task_group:
 
-                    async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
-                        await func()
-                        task_group.cancel_scope.cancel()
+                async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
+                    await func()
+                    task_group.cancel_scope.cancel()
 
-                    task_group.start_soon(wrap, partial(self.stream_response, send))
-                    await wrap(partial(self.listen_for_disconnect, receive))
+                task_group.start_soon(wrap, partial(self.stream_response, send))
+                await wrap(partial(self.listen_for_disconnect, receive))
 
         if self.background is not None:
             await self.background()
index 3511c89c9623a3d6366d1551c52c662aa63c014d..418f0946ffb1bcf00c6c3175707395f4417cc170 100644 (file)
@@ -4,7 +4,6 @@ from typing import Any, Callable
 
 import pytest
 
-from starlette._utils import collapse_excgroups
 from starlette.middleware.wsgi import WSGIMiddleware, build_environ
 from tests.types import TestClientFactory
 
@@ -86,7 +85,7 @@ def test_wsgi_exception(test_client_factory: TestClientFactory) -> None:
     # The HTTP protocol implementations would catch this error and return 500.
     app = WSGIMiddleware(raise_exception)
     client = test_client_factory(app)
-    with pytest.raises(RuntimeError), collapse_excgroups():
+    with pytest.raises(RuntimeError):
         client.get("/")