]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Fix unclosed 'MemoryObjectReceiveStream' upon exception in 'BaseHTTPMiddleware' child...
authorMarcelo Trylesinski <marcelotryle@gmail.com>
Sun, 29 Dec 2024 12:32:00 +0000 (13:32 +0100)
committerGitHub <noreply@github.com>
Sun, 29 Dec 2024 12:32:00 +0000 (13:32 +0100)
pyproject.toml
starlette/middleware/base.py
starlette/testclient.py

index 50a53caf69e6b7d7b0a92cc055de06a4c2d34bf9..95f195c50cd82d0acf872f947f2afa74c1d5dcf9 100644 (file)
@@ -84,8 +84,6 @@ filterwarnings = [
     "ignore: starlette.middleware.wsgi is deprecated and will be removed in a future release.*:DeprecationWarning",
     "ignore: Async generator 'starlette.requests.Request.stream' was garbage collected before it had been exhausted.*:ResourceWarning",
     "ignore: Use 'content=<...>' to upload raw bytes/text content.:DeprecationWarning",
-    # TODO: This warning appeared when we bumped anyio to 4.4.0.
-    "ignore: Unclosed .MemoryObject(Send|Receive)Stream.:ResourceWarning",
 ]
 
 [tool.coverage.run]
index f51b13f733a9739a0ebd7111b82096bf576a8b87..6e37c6f603b537a08eae32070ee3545543b566a5 100644 (file)
@@ -3,7 +3,6 @@ from __future__ import annotations
 import typing
 
 import anyio
-from anyio.abc import ObjectReceiveStream, ObjectSendStream
 
 from starlette._utils import collapse_excgroups
 from starlette.requests import ClientDisconnect, Request
@@ -107,9 +106,6 @@ class BaseHTTPMiddleware:
 
         async def call_next(request: Request) -> Response:
             app_exc: Exception | None = None
-            send_stream: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
-            recv_stream: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
-            send_stream, recv_stream = anyio.create_memory_object_stream()
 
             async def receive_or_disconnect() -> Message:
                 if response_sent.is_set():
@@ -130,10 +126,6 @@ class BaseHTTPMiddleware:
 
                 return message
 
-            async def close_recv_stream_on_response_sent() -> None:
-                await response_sent.wait()
-                recv_stream.close()
-
             async def send_no_error(message: Message) -> None:
                 try:
                     await send_stream.send(message)
@@ -144,13 +136,12 @@ class BaseHTTPMiddleware:
             async def coro() -> None:
                 nonlocal app_exc
 
-                async with send_stream:
+                with send_stream:
                     try:
                         await self.app(scope, receive_or_disconnect, send_no_error)
                     except Exception as exc:
                         app_exc = exc
 
-            task_group.start_soon(close_recv_stream_on_response_sent)
             task_group.start_soon(coro)
 
             try:
@@ -166,14 +157,13 @@ class BaseHTTPMiddleware:
             assert message["type"] == "http.response.start"
 
             async def body_stream() -> typing.AsyncGenerator[bytes, None]:
-                async with recv_stream:
-                    async for message in recv_stream:
-                        assert message["type"] == "http.response.body"
-                        body = message.get("body", b"")
-                        if body:
-                            yield body
-                        if not message.get("more_body", False):
-                            break
+                async for message in recv_stream:
+                    assert message["type"] == "http.response.body"
+                    body = message.get("body", b"")
+                    if body:
+                        yield body
+                    if not message.get("more_body", False):
+                        break
 
                 if app_exc is not None:
                     raise app_exc
@@ -182,11 +172,13 @@ class BaseHTTPMiddleware:
             response.raw_headers = message["headers"]
             return response
 
-        with collapse_excgroups():
+        send_stream, recv_stream = anyio.create_memory_object_stream[Message]()
+        with recv_stream, send_stream, collapse_excgroups():
             async with anyio.create_task_group() as task_group:
                 response = await self.dispatch_func(request, call_next)
                 await response(scope, wrapped_receive, send)
                 response_sent.set()
+                recv_stream.close()
 
     async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
         raise NotImplementedError()  # pragma: no cover
index 9a0abbd7b506d2d13644715c6438306576f90763..4f9788feb8828e70bba98180ee58153815032bb3 100644 (file)
@@ -14,7 +14,6 @@ from urllib.parse import unquote, urljoin
 import anyio
 import anyio.abc
 import anyio.from_thread
-from anyio.abc import ObjectReceiveStream, ObjectSendStream
 from anyio.streams.stapled import StapledObjectStream
 
 from starlette._utils import is_async_callable
@@ -658,12 +657,12 @@ class TestClient(httpx.Client):
             def reset_portal() -> None:
                 self.portal = None
 
-            send1: ObjectSendStream[typing.MutableMapping[str, typing.Any] | None]
-            receive1: ObjectReceiveStream[typing.MutableMapping[str, typing.Any] | None]
-            send2: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
-            receive2: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
-            send1, receive1 = anyio.create_memory_object_stream(math.inf)
-            send2, receive2 = anyio.create_memory_object_stream(math.inf)
+            send1, receive1 = anyio.create_memory_object_stream[
+                typing.Union[typing.MutableMapping[str, typing.Any], None]
+            ](math.inf)
+            send2, receive2 = anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any]](math.inf)
+            for channel in (send1, send2, receive1, receive2):
+                stack.callback(channel.close)
             self.stream_send = StapledObjectStream(send1, receive1)
             self.stream_receive = StapledObjectStream(send2, receive2)
             self.task = portal.start_task_soon(self.lifespan)
@@ -711,12 +710,11 @@ class TestClient(httpx.Client):
                 self.task.result()
             return message
 
-        async with self.stream_send, self.stream_receive:
-            await self.stream_receive.send({"type": "lifespan.shutdown"})
-            message = await receive()
-            assert message["type"] in (
-                "lifespan.shutdown.complete",
-                "lifespan.shutdown.failed",
-            )
-            if message["type"] == "lifespan.shutdown.failed":
-                await receive()
+        await self.stream_receive.send({"type": "lifespan.shutdown"})
+        message = await receive()
+        assert message["type"] in (
+            "lifespan.shutdown.complete",
+            "lifespan.shutdown.failed",
+        )
+        if message["type"] == "lifespan.shutdown.failed":
+            await receive()