import functools
import sys
import typing
-from contextlib import contextmanager
+from contextlib import asynccontextmanager, contextmanager
+
+import anyio.abc
from starlette.types import Scope
del exc, cause, tb, context
+@asynccontextmanager
+async def create_collapsing_task_group() -> typing.AsyncGenerator[anyio.abc.TaskGroup, None]:
+ with collapse_excgroups():
+ async with anyio.create_task_group() as tg:
+ yield tg
+
+
def get_route_path(scope: Scope) -> str:
path: str = scope["path"]
root_path = scope.get("root_path", "")
import anyio
-from starlette._utils import collapse_excgroups
+from starlette._utils import create_collapsing_task_group
from starlette.requests import ClientDisconnect, Request
from starlette.responses import AsyncContentStream, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
return response
send_stream, recv_stream = anyio.create_memory_object_stream[Message]()
- with recv_stream, send_stream, collapse_excgroups():
- async with anyio.create_task_group() as task_group:
+ with recv_stream, send_stream:
+ async with create_collapsing_task_group() as task_group:
response = await self.dispatch_func(request, call_next)
await response(scope, wrapped_receive, send)
response_sent.set()
import anyio
from anyio.abc import ObjectReceiveStream, ObjectSendStream
+from starlette._utils import create_collapsing_task_group
from starlette.types import Receive, Scope, Send
warnings.warn(
more_body = message.get("more_body", False)
environ = build_environ(self.scope, body)
- async with anyio.create_task_group() as task_group:
+ async with create_collapsing_task_group() as task_group:
task_group.start_soon(self.sender, send)
async with self.stream_send:
await anyio.to_thread.run_sync(self.wsgi, environ, self.start_response)
import anyio
import anyio.to_thread
-from starlette._utils import collapse_excgroups
+from starlette._utils import create_collapsing_task_group
from starlette.background import BackgroundTask
from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import URL, Headers, MutableHeaders
except OSError:
raise ClientDisconnect()
else:
- with collapse_excgroups():
- async with anyio.create_task_group() as task_group:
+ async with create_collapsing_task_group() as task_group:
- async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
- await func()
- task_group.cancel_scope.cancel()
+ async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
+ await func()
+ task_group.cancel_scope.cancel()
- task_group.start_soon(wrap, partial(self.stream_response, send))
- await wrap(partial(self.listen_for_disconnect, receive))
+ task_group.start_soon(wrap, partial(self.stream_response, send))
+ await wrap(partial(self.listen_for_disconnect, receive))
if self.background is not None:
await self.background()
import pytest
-from starlette._utils import collapse_excgroups
from starlette.middleware.wsgi import WSGIMiddleware, build_environ
from tests.types import TestClientFactory
# The HTTP protocol implementations would catch this error and return 500.
app = WSGIMiddleware(raise_exception)
client = test_client_factory(app)
- with pytest.raises(RuntimeError), collapse_excgroups():
+ with pytest.raises(RuntimeError):
client.get("/")