]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Fixed import error when exceptiongroup isn't available (#2231)
authorAlex Grönholm <alex.gronholm@nextday.fi>
Fri, 25 Aug 2023 18:29:44 +0000 (21:29 +0300)
committerGitHub <noreply@github.com>
Fri, 25 Aug 2023 18:29:44 +0000 (20:29 +0200)
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
requirements.txt
starlette/_utils.py
starlette/middleware/base.py
tests/middleware/test_wsgi.py

index 65e240832f1e4a42622d1aef727ac1b7c3fd8aac..4794d120e750c20ff75c1912069d227601462095 100644 (file)
@@ -12,8 +12,8 @@ types-contextvars==2.4.7.2
 types-PyYAML==6.0.12.10
 types-dataclasses==0.6.6
 pytest==7.4.0
-trio==0.22.1
-anyio@git+https://github.com/agronholm/anyio.git
+trio==0.21.0
+anyio==3.7.1
 
 # Documentation
 mkdocs==1.4.3
index f06dd557cedc61825c31056dcb8f3a3e7756b88c..26854f3d475f0bb398e08604ecb9c99220a1b0ab 100644 (file)
@@ -2,12 +2,20 @@ import asyncio
 import functools
 import sys
 import typing
+from contextlib import contextmanager
 
 if sys.version_info >= (3, 10):  # pragma: no cover
     from typing import TypeGuard
 else:  # pragma: no cover
     from typing_extensions import TypeGuard
 
+has_exceptiongroups = True
+if sys.version_info < (3, 11):  # pragma: no cover
+    try:
+        from exceptiongroup import BaseExceptionGroup
+    except ImportError:
+        has_exceptiongroups = False
+
 T = typing.TypeVar("T")
 AwaitableCallable = typing.Callable[..., typing.Awaitable[T]]
 
@@ -66,3 +74,15 @@ class AwaitableOrContextManagerWrapper(typing.Generic[SupportsAsyncCloseType]):
     async def __aexit__(self, *args: typing.Any) -> typing.Union[None, bool]:
         await self.entered.close()
         return None
+
+
+@contextmanager
+def collapse_excgroups() -> typing.Generator[None, None, None]:
+    try:
+        yield
+    except BaseException as exc:
+        if has_exceptiongroups:
+            while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1:
+                exc = exc.exceptions[0]  # pragma: no cover
+
+        raise exc
index ee99ee6cb47dbdcb9ce74ab5db64a821d2d0951c..ad3ffcfeefc0f20f7b95da0423c0d514ff0dbb86 100644 (file)
@@ -1,18 +1,14 @@
-import sys
 import typing
-from contextlib import contextmanager
 
 import anyio
 from anyio.abc import ObjectReceiveStream, ObjectSendStream
 
+from starlette._utils import collapse_excgroups
 from starlette.background import BackgroundTask
 from starlette.requests import ClientDisconnect, Request
 from starlette.responses import ContentStream, Response, StreamingResponse
 from starlette.types import ASGIApp, Message, Receive, Scope, Send
 
-if sys.version_info < (3, 11):  # pragma: no cover
-    from exceptiongroup import BaseExceptionGroup
-
 RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
 DispatchFunction = typing.Callable[
     [Request, RequestResponseEndpoint], typing.Awaitable[Response]
@@ -20,17 +16,6 @@ DispatchFunction = typing.Callable[
 T = typing.TypeVar("T")
 
 
-@contextmanager
-def _convert_excgroups() -> typing.Generator[None, None, None]:
-    try:
-        yield
-    except BaseException as exc:
-        while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1:
-            exc = exc.exceptions[0]
-
-        raise exc
-
-
 class _CachedRequest(Request):
     """
     If the user calls Request.body() from their dispatch function
@@ -201,7 +186,7 @@ class BaseHTTPMiddleware:
             response.raw_headers = message["headers"]
             return response
 
-        with _convert_excgroups():
+        with collapse_excgroups():
             async with anyio.create_task_group() as task_group:
                 response = await self.dispatch_func(request, call_next)
                 await response(scope, wrapped_receive, send)
index ad39754036ddd059b3760367f9a9d5f78e1d1858..fe527e373d62f0dd732f6798ca324039a43d3245 100644 (file)
@@ -2,11 +2,9 @@ import sys
 
 import pytest
 
+from starlette._utils import collapse_excgroups
 from starlette.middleware.wsgi import WSGIMiddleware, build_environ
 
-if sys.version_info < (3, 11):  # pragma: no cover
-    from exceptiongroup import ExceptionGroup
-
 
 def hello_world(environ, start_response):
     status = "200 OK"
@@ -69,12 +67,9 @@ def test_wsgi_exception(test_client_factory):
     # The HTTP protocol implementations would catch this error and return 500.
     app = WSGIMiddleware(raise_exception)
     client = test_client_factory(app)
-    with pytest.raises(ExceptionGroup) as exc:
+    with pytest.raises(RuntimeError), collapse_excgroups():
         client.get("/")
 
-    assert len(exc.value.exceptions) == 1
-    assert isinstance(exc.value.exceptions[0], RuntimeError)
-
 
 def test_wsgi_exc_info(test_client_factory):
     # Note that we're testing the WSGI app directly here.