]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Upgraded to AnyIO 4.0 (#2211)
authorAlex Grönholm <alex.gronholm@nextday.fi>
Sun, 23 Jul 2023 06:26:35 +0000 (09:26 +0300)
committerGitHub <noreply@github.com>
Sun, 23 Jul 2023 06:26:35 +0000 (06:26 +0000)
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
requirements.txt
starlette/middleware/base.py
starlette/middleware/wsgi.py
starlette/testclient.py
tests/middleware/test_wsgi.py
tests/test_websockets.py

index 6d044c13defe977c9f267ee4bdaa3352dfe72bd1..65e240832f1e4a42622d1aef727ac1b7c3fd8aac 100644 (file)
@@ -12,7 +12,8 @@ types-contextvars==2.4.7.2
 types-PyYAML==6.0.12.10
 types-dataclasses==0.6.6
 pytest==7.4.0
-trio==0.21.0
+trio==0.22.1
+anyio@git+https://github.com/agronholm/anyio.git
 
 # Documentation
 mkdocs==1.4.3
index 170a805a758452f5ccc5b11cb6b96987f9cdbdcd..ee99ee6cb47dbdcb9ce74ab5db64a821d2d0951c 100644 (file)
@@ -1,12 +1,18 @@
+import sys
 import typing
+from contextlib import contextmanager
 
 import anyio
+from anyio.abc import ObjectReceiveStream, ObjectSendStream
 
 from starlette.background import BackgroundTask
 from starlette.requests import ClientDisconnect, Request
 from starlette.responses import ContentStream, Response, StreamingResponse
 from starlette.types import ASGIApp, Message, Receive, Scope, Send
 
+if sys.version_info < (3, 11):  # pragma: no cover
+    from exceptiongroup import BaseExceptionGroup
+
 RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
 DispatchFunction = typing.Callable[
     [Request, RequestResponseEndpoint], typing.Awaitable[Response]
@@ -14,6 +20,17 @@ DispatchFunction = typing.Callable[
 T = typing.TypeVar("T")
 
 
+@contextmanager
+def _convert_excgroups() -> typing.Generator[None, None, None]:
+    try:
+        yield
+    except BaseException as exc:
+        while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1:
+            exc = exc.exceptions[0]
+
+        raise exc
+
+
 class _CachedRequest(Request):
     """
     If the user calls Request.body() from their dispatch function
@@ -107,6 +124,8 @@ class BaseHTTPMiddleware:
 
         async def call_next(request: Request) -> Response:
             app_exc: typing.Optional[Exception] = None
+            send_stream: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
+            recv_stream: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
             send_stream, recv_stream = anyio.create_memory_object_stream()
 
             async def receive_or_disconnect() -> Message:
@@ -182,10 +201,11 @@ class BaseHTTPMiddleware:
             response.raw_headers = message["headers"]
             return response
 
-        async with anyio.create_task_group() as task_group:
-            response = await self.dispatch_func(request, call_next)
-            await response(scope, wrapped_receive, send)
-            response_sent.set()
+        with _convert_excgroups():
+            async with anyio.create_task_group() as task_group:
+                response = await self.dispatch_func(request, call_next)
+                await response(scope, wrapped_receive, send)
+                response_sent.set()
 
     async def dispatch(
         self, request: Request, call_next: RequestResponseEndpoint
index 9dbd065284aaf903ef0153a20027f66e2530eb65..d4a117cacfa97f68a1d409a808dd5aaaa4ee2601 100644 (file)
@@ -5,6 +5,7 @@ import typing
 import warnings
 
 import anyio
+from anyio.abc import ObjectReceiveStream, ObjectSendStream
 
 from starlette.types import Receive, Scope, Send
 
@@ -72,6 +73,9 @@ class WSGIMiddleware:
 
 
 class WSGIResponder:
+    stream_send: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
+    stream_receive: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
+
     def __init__(self, app: typing.Callable, scope: Scope) -> None:
         self.app = app
         self.scope = scope
index a91ad7bfc8f242d16269ded6f575eec088008fdb..c9ae97a0824aaf1c5fb821aa4bd75c3a62103851 100644 (file)
@@ -12,6 +12,7 @@ from urllib.parse import unquote, urljoin
 
 import anyio
 import anyio.from_thread
+from anyio.abc import ObjectReceiveStream, ObjectSendStream
 from anyio.streams.stapled import StapledObjectStream
 
 from starlette._utils import is_async_callable
@@ -737,12 +738,18 @@ class TestClient(httpx.Client):
             def reset_portal() -> None:
                 self.portal = None
 
-            self.stream_send = StapledObjectStream(
-                *anyio.create_memory_object_stream(math.inf)
-            )
-            self.stream_receive = StapledObjectStream(
-                *anyio.create_memory_object_stream(math.inf)
-            )
+            send1: ObjectSendStream[
+                typing.Optional[typing.MutableMapping[str, typing.Any]]
+            ]
+            receive1: ObjectReceiveStream[
+                typing.Optional[typing.MutableMapping[str, typing.Any]]
+            ]
+            send2: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
+            receive2: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
+            send1, receive1 = anyio.create_memory_object_stream(math.inf)
+            send2, receive2 = anyio.create_memory_object_stream(math.inf)
+            self.stream_send = StapledObjectStream(send1, receive1)
+            self.stream_receive = StapledObjectStream(send2, receive2)
             self.task = portal.start_task_soon(self.lifespan)
             portal.call(self.wait_startup)
 
index bcb4cd6ff28751a4b44d1316a4e9d23f644a0f83..ad39754036ddd059b3760367f9a9d5f78e1d1858 100644 (file)
@@ -4,6 +4,9 @@ import pytest
 
 from starlette.middleware.wsgi import WSGIMiddleware, build_environ
 
+if sys.version_info < (3, 11):  # pragma: no cover
+    from exceptiongroup import ExceptionGroup
+
 
 def hello_world(environ, start_response):
     status = "200 OK"
@@ -66,9 +69,12 @@ def test_wsgi_exception(test_client_factory):
     # The HTTP protocol implementations would catch this error and return 500.
     app = WSGIMiddleware(raise_exception)
     client = test_client_factory(app)
-    with pytest.raises(RuntimeError):
+    with pytest.raises(ExceptionGroup) as exc:
         client.get("/")
 
+    assert len(exc.value.exceptions) == 1
+    assert isinstance(exc.value.exceptions[0], RuntimeError)
+
 
 def test_wsgi_exc_info(test_client_factory):
     # Note that we're testing the WSGI app directly here.
index c1ec1153e7df0e12b01dca58ab5ef3b6f06e2ea2..71bccd455d15117148eac8c4afcc724ea77b9ee7 100644 (file)
@@ -1,7 +1,9 @@
 import sys
+from typing import Any, MutableMapping
 
 import anyio
 import pytest
+from anyio.abc import ObjectReceiveStream, ObjectSendStream
 
 from starlette import status
 from starlette.types import Receive, Scope, Send
@@ -178,6 +180,8 @@ def test_websocket_iter_json(test_client_factory):
 
 
 def test_websocket_concurrency_pattern(test_client_factory):
+    stream_send: ObjectSendStream[MutableMapping[str, Any]]
+    stream_receive: ObjectReceiveStream[MutableMapping[str, Any]]
     stream_send, stream_receive = anyio.create_memory_object_stream()
 
     async def reader(websocket):